openprotein-python 0.8.2__1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openprotein/__init__.py +164 -0
- openprotein/_version.py +48 -0
- openprotein/align/__init__.py +8 -0
- openprotein/align/align.py +395 -0
- openprotein/align/api.py +428 -0
- openprotein/align/future.py +55 -0
- openprotein/align/msa.py +129 -0
- openprotein/align/schemas.py +165 -0
- openprotein/base.py +181 -0
- openprotein/chains.py +88 -0
- openprotein/common/__init__.py +5 -0
- openprotein/common/features.py +7 -0
- openprotein/common/model_metadata.py +33 -0
- openprotein/common/reduction.py +8 -0
- openprotein/config.py +9 -0
- openprotein/csv.py +31 -0
- openprotein/data/__init__.py +9 -0
- openprotein/data/api.py +218 -0
- openprotein/data/assaydataset.py +178 -0
- openprotein/data/data.py +93 -0
- openprotein/data/schemas.py +27 -0
- openprotein/design/__init__.py +16 -0
- openprotein/design/api.py +259 -0
- openprotein/design/design.py +125 -0
- openprotein/design/future.py +146 -0
- openprotein/design/schemas.py +607 -0
- openprotein/embeddings/__init__.py +27 -0
- openprotein/embeddings/api.py +619 -0
- openprotein/embeddings/embeddings.py +151 -0
- openprotein/embeddings/esm.py +33 -0
- openprotein/embeddings/future.py +146 -0
- openprotein/embeddings/models.py +421 -0
- openprotein/embeddings/openprotein.py +21 -0
- openprotein/embeddings/poet.py +446 -0
- openprotein/embeddings/poet2.py +505 -0
- openprotein/embeddings/schemas.py +78 -0
- openprotein/errors.py +76 -0
- openprotein/fasta.py +92 -0
- openprotein/fold/__init__.py +21 -0
- openprotein/fold/alphafold2.py +131 -0
- openprotein/fold/api.py +287 -0
- openprotein/fold/boltz.py +691 -0
- openprotein/fold/esmfold.py +54 -0
- openprotein/fold/fold.py +107 -0
- openprotein/fold/future.py +509 -0
- openprotein/fold/models.py +139 -0
- openprotein/fold/schemas.py +39 -0
- openprotein/jobs/__init__.py +9 -0
- openprotein/jobs/api.py +71 -0
- openprotein/jobs/futures.py +746 -0
- openprotein/jobs/jobs.py +69 -0
- openprotein/jobs/schemas.py +135 -0
- openprotein/models/__init__.py +4 -0
- openprotein/models/base.py +63 -0
- openprotein/models/foundation/rfdiffusion.py +283 -0
- openprotein/models/models.py +33 -0
- openprotein/predictor/__init__.py +25 -0
- openprotein/predictor/api.py +384 -0
- openprotein/predictor/models.py +374 -0
- openprotein/predictor/prediction.py +79 -0
- openprotein/predictor/predictor.py +242 -0
- openprotein/predictor/schemas.py +113 -0
- openprotein/predictor/validate.py +40 -0
- openprotein/prompt/__init__.py +9 -0
- openprotein/prompt/api.py +505 -0
- openprotein/prompt/models.py +142 -0
- openprotein/prompt/prompt.py +130 -0
- openprotein/prompt/schemas.py +49 -0
- openprotein/protein.py +587 -0
- openprotein/svd/__init__.py +9 -0
- openprotein/svd/api.py +206 -0
- openprotein/svd/models.py +288 -0
- openprotein/svd/schemas.py +31 -0
- openprotein/svd/svd.py +134 -0
- openprotein/umap/__init__.py +9 -0
- openprotein/umap/api.py +259 -0
- openprotein/umap/models.py +211 -0
- openprotein/umap/schemas.py +35 -0
- openprotein/umap/umap.py +175 -0
- openprotein/utils/uuid.py +29 -0
- openprotein_python-0.8.2.dist-info/METADATA +176 -0
- openprotein_python-0.8.2.dist-info/RECORD +84 -0
- openprotein_python-0.8.2.dist-info/WHEEL +4 -0
- openprotein_python-0.8.2.dist-info/licenses/LICENSE.txt +30 -0
|
@@ -0,0 +1,374 @@
|
|
|
1
|
+
"""Predictor models for making predictions on new sequences."""
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
from openprotein.base import APISession
|
|
6
|
+
from openprotein.data import AssayDataset, DataAPI
|
|
7
|
+
from openprotein.embeddings import EmbeddingModel, EmbeddingsAPI
|
|
8
|
+
from openprotein.errors import InvalidParameterError
|
|
9
|
+
from openprotein.jobs import Future, JobsAPI, JobType
|
|
10
|
+
from openprotein.svd import SVDAPI, SVDModel
|
|
11
|
+
|
|
12
|
+
from . import api
|
|
13
|
+
from .prediction import PredictionResultFuture
|
|
14
|
+
from .schemas import (
|
|
15
|
+
PredictorEnsembleJob,
|
|
16
|
+
PredictorMetadata,
|
|
17
|
+
PredictorTrainJob,
|
|
18
|
+
PredictorType,
|
|
19
|
+
)
|
|
20
|
+
from .validate import CVResultFuture
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from openprotein.design import ModelCriterion
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class PredictorModel(Future):
    """
    Class providing predict endpoint for fitted predictor models.

    Also implements a Future that waits for train job.
    """

    job: PredictorTrainJob | None

    def __init__(
        self,
        session: APISession,
        job: PredictorTrainJob | PredictorEnsembleJob | None = None,
        metadata: PredictorMetadata | None = None,
    ):
        """
        Construct a predictor model.

        Takes in either a train job, or the predictor metadata. When only the
        job is given, the metadata is fetched from the predictor endpoint. When
        only the metadata is given, the train job is looked up through the jobs
        API, or, for ensemble predictors, a job object is built from the
        metadata.

        Raises
        ------
        ValueError
            If neither metadata nor a job with a job_id is provided.

        :meta private:
        """
        # lazily populated by the `training_assay` property
        self._training_assay = None

        # initialize the predictor metadata
        if metadata is None:
            if job is None or job.job_id is None:
                raise ValueError("Expected predictor metadata or job")
            metadata = api.predictor_get(session, job.job_id)
        self._metadata = metadata

        # initialize the job backing this future
        if job is None:
            if metadata.model_spec.type != PredictorType.ENSEMBLE:
                jobs_api = getattr(session, "jobs", None)
                assert isinstance(jobs_api, JobsAPI)
                job = PredictorTrainJob.create(jobs_api.get_job(job_id=metadata.id))
            else:
                # ensemble predictors are not looked up through the jobs API;
                # a job object is built from the predictor metadata instead
                job = PredictorEnsembleJob(
                    created_date=metadata.created_date,
                    status=metadata.status,
                    job_type=JobType.predictor_train,
                )
        super().__init__(session, job)

    def __str__(self) -> str:
        return str(self.metadata)

    def __repr__(self) -> str:
        return repr(self.metadata)

    def __or__(self, model: "PredictorModel") -> "PredictorModelGroup":
        """
        Group this predictor with another one for joint predictions.

        Sequence length constraints must be compatible: either both are equal,
        or at least one side is unconstrained (None).

        Raises
        ------
        ValueError
            If both predictors are constrained to different sequence lengths.
        """
        own_length = self.sequence_length
        other_length = model.sequence_length
        # None means "no constraint" and is compatible with any length
        if (
            own_length is not None
            and other_length is not None
            and own_length != other_length
        ):
            raise ValueError("Expected sequence lengths to either match or be None.")
        return PredictorModelGroup(
            session=self.session,
            models=[self, model],
            sequence_length=own_length if own_length is not None else other_length,
            check_sequence_length=False,
        )

    def _to_criterion(self, target: float, direction: str) -> "ModelCriterion":
        """
        Build a design ModelCriterion on this predictor's single trained property.

        Parameters
        ----------
        target : float
            Target value of the criterion.
        direction : str
            Name of the ``ModelCriterion.Criterion.DirectionEnum`` member to use
            (``"lt"``, ``"gt"`` or ``"eq"``).

        Raises
        ------
        InvalidMultitaskModelToCriterion
            If the predictor was not trained on exactly one property.
        """
        from openprotein.design import ModelCriterion

        properties = self.training_properties
        if len(properties) != 1:
            raise self.InvalidMultitaskModelToCriterion(
                "Expected a predictor trained on exactly one property to build a "
                f"criterion, got {len(properties)} properties"
            )
        return ModelCriterion(
            model_id=self.id,
            measurement_name=properties[0],
            criterion=ModelCriterion.Criterion(
                target=target,
                direction=getattr(ModelCriterion.Criterion.DirectionEnum, direction),
            ),
        )

    def __lt__(self, target: float) -> "ModelCriterion":
        """Build an ``lt`` criterion (``predictor < target``) for design."""
        return self._to_criterion(target, "lt")

    def __gt__(self, target: float) -> "ModelCriterion":
        """Build a ``gt`` criterion (``predictor > target``) for design."""
        return self._to_criterion(target, "gt")

    # NOTE: defining __eq__ in a class body makes Python set __hash__ to None,
    # so predictor models are unhashable.
    def __eq__(self, target: float) -> "ModelCriterion":
        """Build an ``eq`` criterion (``predictor == target``) for design."""
        return self._to_criterion(target, "eq")

    class InvalidMultitaskModelToCriterion(Exception):
        """
        Exception raised when trying to create model criterion from multitask predictor.

        :meta private:
        """

    @property
    def id(self):
        """ID of predictor."""
        return self._metadata.id

    @property
    def reduction(self):
        """The reduction of the embeddings used to train the predictor, if any."""
        features = self._metadata.model_spec.features
        return features.reduction if features is not None else None

    @property
    def sequence_length(self):
        """The sequence length constraint on the predictor, if any."""
        constraints = self._metadata.model_spec.constraints
        if constraints is not None:
            return constraints.sequence_length
        return None

    @property
    def training_assay(self) -> AssayDataset:
        """The assay the predictor was trained on (fetched once, then cached)."""
        if self._training_assay is None:
            self._training_assay = self.get_assay()
        return self._training_assay

    @property
    def training_properties(self) -> list[str]:
        """The list of properties the predictor was trained on."""
        return self._metadata.training_dataset.properties

    @property
    def metadata(self):
        """The predictor metadata, refreshed from the backend while not done."""
        self._refresh_metadata()
        return self._metadata

    def _refresh_metadata(self):
        # metadata is only re-fetched while the predictor is not yet done
        if not self._metadata.is_done():
            self._metadata = api.predictor_get(self.session, self._metadata.id)

    def get_model(self) -> EmbeddingModel | SVDModel | None:
        """
        Retrieve the embeddings or SVD model used to create embeddings to train on.

        Returns
        -------
        EmbeddingModel or SVDModel or None
            The feature model, or None if the predictor has no features or no
            feature model id.

        Raises
        ------
        ValueError
            If the feature type is neither PLM nor SVD.
        """
        features = self._metadata.model_spec.features
        if features is None or features.model_id is None:
            return None
        model_id = features.model_id
        feature_type = features.type.upper()
        if feature_type == "PLM":
            return EmbeddingModel.create(session=self.session, model_id=model_id)
        if feature_type == "SVD":
            svd_api = getattr(self.session, "svd", None)
            assert isinstance(svd_api, SVDAPI)
            return svd_api.get_svd(svd_id=model_id)
        raise ValueError(f"Unexpected feature type {features.type}")

    @property
    def model(self) -> EmbeddingModel | SVDModel | None:
        """The embeddings or SVD model used to create embeddings to train on."""
        return self.get_model()

    def delete(self) -> bool:
        """
        Delete this predictor model.

        Returns
        -------
        bool
            Result of the delete request.
        """
        return api.predictor_delete(self.session, self.id)

    def get(self, verbose: bool = False):
        """
        Returns the train loss curves.

        Reads ``traingraphs`` from the (refreshed) predictor metadata.
        ``verbose`` is accepted for interface compatibility and is unused.
        """
        return self.metadata.traingraphs

    def get_assay(self) -> AssayDataset:
        """
        Get assay used for train job.

        Returns
        -------
        AssayDataset: Assay dataset used for train job.
        """
        data_api = getattr(self.session, "data", None)
        assert isinstance(data_api, DataAPI)
        return data_api.get(assay_id=self._metadata.training_dataset.assay_id)

    def crossvalidate(self, n_splits: int | None = None) -> CVResultFuture:
        """
        Run a crossvalidation on the trained predictor.

        Parameters
        ----------
        n_splits : int or None, optional
            Number of crossvalidation splits; forwarded to the backend as-is.

        Returns
        -------
        CVResultFuture
            Future for the submitted crossvalidation job.
        """
        return CVResultFuture.create(
            session=self.session,
            job=api.predictor_crossvalidate_post(
                session=self.session,
                predictor_id=self.id,
                n_splits=n_splits,
            ),
        )

    def predict(self, sequences: list[bytes] | list[str]) -> PredictionResultFuture:
        """
        Make predictions about the trained properties for a list of sequences.

        Parameters
        ----------
        sequences : list of bytes or list of str
            Sequences to predict on.

        Returns
        -------
        PredictionResultFuture
            Future for the submitted prediction job.

        Raises
        ------
        InvalidParameterError
            If the predictor has a sequence length constraint and any sequence
            does not have that length.
        """
        expected_length = self.sequence_length
        if expected_length is not None:
            for sequence in sequences:
                # convert to string to check token length
                text = sequence if isinstance(sequence, str) else sequence.decode()
                if len(text) != expected_length:
                    raise InvalidParameterError(
                        f"Expected sequences to predict to be of length {expected_length}"
                    )
        return PredictionResultFuture.create(
            session=self.session,
            job=api.predictor_predict_post(
                session=self.session, predictor_id=self.id, sequences=sequences
            ),
        )

    def single_site(self, sequence: bytes | str) -> PredictionResultFuture:
        """
        Compute the single-site mutated predictions of a base sequence.

        Parameters
        ----------
        sequence : bytes or str
            Base sequence to mutate.

        Returns
        -------
        PredictionResultFuture
            Future for the submitted single-site prediction job.

        Raises
        ------
        InvalidParameterError
            If the predictor has a sequence length constraint and the base
            sequence does not have that length.
        """
        expected_length = self.sequence_length
        if expected_length is not None:
            # convert to string to check token length
            text = sequence if isinstance(sequence, str) else sequence.decode()
            if len(text) != expected_length:
                raise InvalidParameterError(
                    f"Expected sequence to predict to be of length {expected_length}"
                )
        return PredictionResultFuture.create(
            session=self.session,
            job=api.predictor_predict_single_site_post(
                session=self.session, predictor_id=self.id, base_sequence=sequence
            ),
        )
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
class PredictorModelGroup(Future):
    """
    Group of fitted predictor models that are queried together.

    Created by combining predictors with ``|`` (e.g. ``model_a | model_b``).
    Predictions and single-site predictions on the group are submitted as a
    single multi-predictor job covering every member model.
    """

    __models__: list[PredictorModel]

    def __init__(
        self,
        session: APISession,
        models: list[PredictorModel],
        sequence_length: int | None = None,
        check_sequence_length: bool = True,  # turn off checking - prevent n^2 operation when chaining many
    ):
        """
        Construct a predictor model group.

        Parameters
        ----------
        session : APISession
            Session used to submit jobs.
        models : list of PredictorModel
            Member predictors; must be non-empty.
        sequence_length : int or None, optional
            Shared sequence length constraint, if any.
        check_sequence_length : bool, optional
            Whether to derive and validate ``sequence_length`` from the member
            models. When False, the given ``sequence_length`` is trusted as-is.

        Raises
        ------
        ValueError
            If ``models`` is empty, or if member sequence length constraints
            conflict.

        :meta private:
        """
        # NOTE(review): Future.__init__ is not called, so no `job` is attached
        # to the group.
        if not models:
            raise ValueError("Expected at least one model to group")
        # calculate and check sequence length compatibility
        if check_sequence_length:
            for member in models:
                member_length = member.sequence_length
                if member_length is None:
                    continue
                if sequence_length is None:
                    sequence_length = member_length
                elif sequence_length != member_length:
                    raise ValueError(
                        "Expected sequence lengths of all models to either match or be None."
                    )
        self.sequence_length = sequence_length
        self.session = session
        self.__models__ = models

    def __str__(self) -> str:
        return repr(self.__models__)

    def __repr__(self) -> str:
        return repr(self.__models__)

    def __or__(self, model: PredictorModel) -> "PredictorModelGroup":
        """
        Add another predictor to this group, returning a new group.

        Sequence length constraints must be compatible: either equal, or at
        least one side unconstrained (None).

        Raises
        ------
        ValueError
            If the group and the model are constrained to different lengths.
        """
        other_length = model.sequence_length
        # None means "no constraint" and is compatible with any length
        if (
            self.sequence_length is not None
            and other_length is not None
            and other_length != self.sequence_length
        ):
            raise ValueError("Expected sequence lengths to either match or be None.")
        return PredictorModelGroup(
            session=self.session,
            models=self.__models__ + [model],
            sequence_length=(
                self.sequence_length
                if self.sequence_length is not None
                else other_length
            ),
            check_sequence_length=False,
        )

    def predict(self, sequences: list[bytes] | list[str]) -> PredictionResultFuture:
        """
        Make predictions about the trained properties for a list of sequences.

        Parameters
        ----------
        sequences : list of bytes or list of str
            Sequences to predict on with every member predictor.

        Returns
        -------
        PredictionResultFuture
            Future for the submitted multi-predictor job.

        Raises
        ------
        InvalidParameterError
            If the group has a sequence length constraint and any sequence does
            not have that length.
        """
        if self.sequence_length is not None:
            for sequence in sequences:
                # convert to string to check token length
                text = sequence if isinstance(sequence, str) else sequence.decode()
                if len(text) != self.sequence_length:
                    raise InvalidParameterError(
                        f"Expected sequences to predict to be of length {self.sequence_length}"
                    )
        return PredictionResultFuture.create(
            session=self.session,
            job=api.predictor_predict_multi_post(
                session=self.session,
                predictor_ids=[m.id for m in self.__models__],
                sequences=sequences,
            ),
        )

    def single_site(self, sequence: bytes | str) -> PredictionResultFuture:
        """
        Compute the single-site mutated predictions of a base sequence.

        The prediction is submitted as one multi-predictor single-site job over
        every member predictor (the group itself has no predictor id).

        Parameters
        ----------
        sequence : bytes or str
            Base sequence to mutate.

        Returns
        -------
        PredictionResultFuture
            Future for the submitted multi-predictor single-site job.

        Raises
        ------
        InvalidParameterError
            If the group has a sequence length constraint and the base sequence
            does not have that length.
        """
        if self.sequence_length is not None:
            # convert to string to check token length
            text = sequence if isinstance(sequence, str) else sequence.decode()
            if len(text) != self.sequence_length:
                raise InvalidParameterError(
                    f"Expected sequence to predict to be of length {self.sequence_length}"
                )
        return PredictionResultFuture.create(
            session=self.session,
            job=api.predictor_predict_multi_single_site_post(
                session=self.session,
                predictor_ids=[m.id for m in self.__models__],
                base_sequence=sequence,
            ),
        )

    def get(self, verbose: bool = False):
        """
        Returns the predictor model.

        :meta private:
        """
        return self

    def delete(self) -> bool:
        """
        Delete every predictor model in this group.

        Returns
        -------
        bool
            True only if every member predictor was deleted successfully.
        """
        # The group has no id of its own, so deletion applies to its members.
        # A list (not a generator) is used so every member delete is attempted
        # even if an earlier one returns False.
        results = [
            api.predictor_delete(session=self.session, predictor_id=m.id)
            for m in self.__models__
        ]
        return all(results)
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
"""Prediction results represented as futures."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from openprotein.base import APISession
|
|
6
|
+
from openprotein.jobs import Future
|
|
7
|
+
|
|
8
|
+
from . import api
|
|
9
|
+
from .schemas import (
|
|
10
|
+
PredictJob,
|
|
11
|
+
PredictMultiJob,
|
|
12
|
+
PredictMultiSingleSiteJob,
|
|
13
|
+
PredictSingleSiteJob,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class PredictionResultFuture(Future):
    """Prediction results represented as a future."""

    job: PredictJob | PredictSingleSiteJob | PredictMultiJob | PredictMultiSingleSiteJob

    def __init__(
        self,
        session: APISession,
        job: (
            PredictJob
            | PredictSingleSiteJob
            | PredictMultiJob
            | PredictMultiSingleSiteJob
        ),
        sequences: list[bytes] | None = None,
    ):
        """
        Construct a prediction result future.

        Args:
            session (APISession): session used to query results
            job: the prediction job backing this future
            sequences (list[bytes] | None): sequences of the job, if already
                known; otherwise they are fetched on first access of
                `sequences`
        """
        super().__init__(session, job)
        self._sequences = sequences

    @property
    def sequences(self):
        """Sequences of the prediction job, fetched once and then cached."""
        if self._sequences is None:
            self._sequences = api.predictor_predict_get_sequences(
                self.session, self.job.job_id
            )
        return self._sequences

    @property
    def id(self):
        """ID of the prediction job."""
        return self.job.job_id

    def __keys__(self):
        # NOTE(review): presumably used by the Future base class to iterate the
        # per-sequence results fetched by `get_item` - confirm.
        return self.sequences

    def get_item(self, sequence: bytes) -> tuple[np.ndarray, np.ndarray]:
        """
        Get prediction results for specified sequence.

        Args:
            sequence (bytes): sequence to fetch results for

        Returns:
            mu (np.ndarray): means of sequence prediction
            var (np.ndarray): variances of sequence prediction
        """
        data = api.predictor_predict_get_sequence_result(
            self.session, self.job.job_id, sequence
        )
        return api.decode_predict(data)

    def get(self, verbose: bool = False) -> tuple[np.ndarray, np.ndarray]:
        """
        Get prediction results for all sequences of the job.

        Args:
            verbose (bool): unused; accepted for interface compatibility

        Returns:
            mu (np.ndarray): means of predictions
            var (np.ndarray): variances of predictions
        """
        data = api.predictor_predict_get_batched_result(self.session, self.job.job_id)
        return api.decode_predict(data, batched=True)
|
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
"""Predictor API providing the interface to train and predict predictors."""
|
|
2
|
+
|
|
3
|
+
from openprotein.base import APISession
|
|
4
|
+
from openprotein.common import FeatureType, ReductionType
|
|
5
|
+
from openprotein.data import (
|
|
6
|
+
AssayDataset,
|
|
7
|
+
AssayMetadata,
|
|
8
|
+
)
|
|
9
|
+
from openprotein.embeddings import EmbeddingModel, EmbeddingsAPI
|
|
10
|
+
from openprotein.errors import InvalidParameterError
|
|
11
|
+
from openprotein.svd import SVDAPI, SVDModel
|
|
12
|
+
|
|
13
|
+
from . import api
|
|
14
|
+
from .models import PredictorModel
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class PredictorAPI:
    """Predictor API providing the interface to train and predict predictors."""

    def __init__(
        self,
        session: APISession,
    ):
        self.session = session

    def get_predictor(
        self,
        predictor_id: str,
        include_stats: bool = False,
        include_calibration_curves: bool = False,
    ) -> PredictorModel:
        """
        Get predictor by model_id.

        PredictorModel allows all the usual prediction job manipulation:
        e.g. making POST and GET requests for this predictor specifically.

        Parameters
        ----------
        predictor_id : str
            The model identifier.
        include_stats : bool
            Whether to include stats of the predictor from the latest evaluation
            (pearson, spearman, ece). Default is False.
        include_calibration_curves : bool
            Whether to include calibration curves of the predictor from the latest
            evaluation. Default is False.

        Returns
        -------
        PredictorModel
            The predictor model to inspect and make predictions with.

        Raises
        ------
        HTTPError
            If the GET request does not succeed.
        """
        metadata = api.predictor_get(
            session=self.session,
            predictor_id=predictor_id,
            include_stats=include_stats,
            include_calibration_curves=include_calibration_curves,
        )
        return PredictorModel(
            session=self.session,
            metadata=metadata,
        )

    def list_predictors(
        self,
        limit: int = 100,
        offset: int = 0,
        include_stats: bool = False,
        include_calibration_curves: bool = False,
    ) -> list[PredictorModel]:
        """
        List predictors available.

        Parameters
        ----------
        limit : int
            Limit of the number of predictors to return in list. Default is 100.
        offset : int
            Offset to the predictors to query for paged queries. Default is 0.
        include_stats : bool
            Whether to include stats of each predictor from the latest evaluation
            (pearson, spearman, ece). Default is False.
        include_calibration_curves : bool
            Whether to include calibration curves of each predictor from the latest
            evaluation. Default is False.

        Returns
        -------
        list[PredictorModel]
            List of predictor models to inspect and make predictions with.

        Raises
        ------
        HTTPError
            If the GET request does not succeed.
        """
        metadatas = api.predictor_list(
            session=self.session,
            limit=limit,
            offset=offset,
            include_stats=include_stats,
            include_calibration_curves=include_calibration_curves,
        )
        return [
            PredictorModel(
                session=self.session,
                metadata=m,
            )
            for m in metadatas
        ]

    def fit_gp(
        self,
        assay: AssayDataset | AssayMetadata | str,
        properties: list[str],
        model: EmbeddingModel | SVDModel | str,
        feature_type: FeatureType | None = None,
        reduction: ReductionType | None = None,
        name: str | None = None,
        description: str | None = None,
        **kwargs,
    ) -> PredictorModel:
        """
        Fit a GP on an assay with the specified feature model and hyperparameters.

        Parameters
        ----------
        assay : AssayMetadata or AssayDataset or str
            Assay to fit GP on.
        properties : list of str
            Properties in the assay to fit the gp on.
        model : EmbeddingModel or SVDModel or str
            Instance of either EmbeddingModel or SVDModel to use depending
            on feature type. Can also be a str specifying the model id,
            but then feature_type would have to be specified.
        feature_type : FeatureType or None
            Type of features to use for encoding sequences. "SVD" or "PLM".
            None would require model to be EmbeddingModel or SVDModel.
            Ignored when model is an EmbeddingModel or SVDModel instance.
        reduction : str or None, optional
            Type of embedding reduction to use for computing features.
            E.g. "MEAN" or "SUM". Used only if using EmbeddingModel, and
            must be non-nil if using an EmbeddingModel. Defaults to None.
        name : str or None, optional
            Name of the predictor.
        description : str or None, optional
            Description of the predictor.
        kwargs :
            Additional keyword arguments to be passed to foundational models, e.g. prompt_id for PoET models.

        Returns
        -------
        PredictorModel
            The GP model being fit.

        Raises
        ------
        InvalidParameterError
            If the feature type cannot be determined or is not PLM/SVD, or if
            reduction is missing for PLM features.
        """
        # extract feature type: a model instance determines it, otherwise the
        # explicitly provided feature_type is used (needed for str model ids)
        if isinstance(model, EmbeddingModel):
            feature_type = FeatureType.PLM
        elif isinstance(model, SVDModel):
            feature_type = FeatureType.SVD
        if feature_type is None:
            raise InvalidParameterError(
                "Expected feature_type to be provided if passing str model_id as model"
            )
        # resolve model_id, fetching the model first if given as a str id
        if feature_type == FeatureType.PLM:
            if reduction is None:
                raise InvalidParameterError(
                    "Expected reduction if using EmbeddingModel"
                )
            if isinstance(model, str):
                embeddings_api = getattr(self.session, "embedding", None)
                assert isinstance(embeddings_api, EmbeddingsAPI)
                model = embeddings_api.get_model(model)
            assert isinstance(model, EmbeddingModel), "Expected EmbeddingModel"
            model_id = model.id
        elif feature_type == FeatureType.SVD:
            if isinstance(model, str):
                svd_api = getattr(self.session, "svd", None)
                assert isinstance(svd_api, SVDAPI)
                model = svd_api.get_svd(model)
            assert isinstance(model, SVDModel), "Expected SVDModel"
            model_id = model.id
        else:
            # previously this fell through and failed with an unbound model_id
            raise InvalidParameterError(
                f"Unexpected feature_type {feature_type}, expected PLM or SVD"
            )
        # get assay_id
        if isinstance(assay, AssayMetadata):
            assay_id = assay.assay_id
        elif isinstance(assay, AssayDataset):
            assay_id = assay.id
        else:
            assay_id = assay
        return PredictorModel(
            session=self.session,
            job=api.predictor_fit_gp_post(
                session=self.session,
                assay_id=assay_id,
                properties=properties,
                feature_type=feature_type,
                model_id=model_id,
                reduction=reduction,
                name=name,
                description=description,
                **kwargs,
            ),
        )

    def delete_predictor(self, predictor_id: str) -> bool:
        """
        Delete predictor model.

        Parameters
        ----------
        predictor_id : str
            The ID of the predictor.

        Returns
        -------
        bool
            True: successful deletion
        """
        return api.predictor_delete(session=self.session, predictor_id=predictor_id)

    def ensemble(self, predictors: list[PredictorModel]) -> PredictorModel:
        """
        Ensemble predictor models together.

        Parameters
        ----------
        predictors : list[PredictorModel]
            List of predictors to ensemble together.

        Returns
        -------
        PredictorModel
            Ensembled predictor model

        Raises
        ------
        InvalidParameterError
            If no predictors are given.
        """
        if not predictors:
            raise InvalidParameterError("Expected at least one predictor to ensemble")
        return PredictorModel(
            session=self.session,
            metadata=api.predictor_ensemble(
                session=self.session,
                predictor_ids=[predictor.id for predictor in predictors],
            ),
        )
|