openprotein-python 0.8.2__1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84) hide show
  1. openprotein/__init__.py +164 -0
  2. openprotein/_version.py +48 -0
  3. openprotein/align/__init__.py +8 -0
  4. openprotein/align/align.py +395 -0
  5. openprotein/align/api.py +428 -0
  6. openprotein/align/future.py +55 -0
  7. openprotein/align/msa.py +129 -0
  8. openprotein/align/schemas.py +165 -0
  9. openprotein/base.py +181 -0
  10. openprotein/chains.py +88 -0
  11. openprotein/common/__init__.py +5 -0
  12. openprotein/common/features.py +7 -0
  13. openprotein/common/model_metadata.py +33 -0
  14. openprotein/common/reduction.py +8 -0
  15. openprotein/config.py +9 -0
  16. openprotein/csv.py +31 -0
  17. openprotein/data/__init__.py +9 -0
  18. openprotein/data/api.py +218 -0
  19. openprotein/data/assaydataset.py +178 -0
  20. openprotein/data/data.py +93 -0
  21. openprotein/data/schemas.py +27 -0
  22. openprotein/design/__init__.py +16 -0
  23. openprotein/design/api.py +259 -0
  24. openprotein/design/design.py +125 -0
  25. openprotein/design/future.py +146 -0
  26. openprotein/design/schemas.py +607 -0
  27. openprotein/embeddings/__init__.py +27 -0
  28. openprotein/embeddings/api.py +619 -0
  29. openprotein/embeddings/embeddings.py +151 -0
  30. openprotein/embeddings/esm.py +33 -0
  31. openprotein/embeddings/future.py +146 -0
  32. openprotein/embeddings/models.py +421 -0
  33. openprotein/embeddings/openprotein.py +21 -0
  34. openprotein/embeddings/poet.py +446 -0
  35. openprotein/embeddings/poet2.py +505 -0
  36. openprotein/embeddings/schemas.py +78 -0
  37. openprotein/errors.py +76 -0
  38. openprotein/fasta.py +92 -0
  39. openprotein/fold/__init__.py +21 -0
  40. openprotein/fold/alphafold2.py +131 -0
  41. openprotein/fold/api.py +287 -0
  42. openprotein/fold/boltz.py +691 -0
  43. openprotein/fold/esmfold.py +54 -0
  44. openprotein/fold/fold.py +107 -0
  45. openprotein/fold/future.py +509 -0
  46. openprotein/fold/models.py +139 -0
  47. openprotein/fold/schemas.py +39 -0
  48. openprotein/jobs/__init__.py +9 -0
  49. openprotein/jobs/api.py +71 -0
  50. openprotein/jobs/futures.py +746 -0
  51. openprotein/jobs/jobs.py +69 -0
  52. openprotein/jobs/schemas.py +135 -0
  53. openprotein/models/__init__.py +4 -0
  54. openprotein/models/base.py +63 -0
  55. openprotein/models/foundation/rfdiffusion.py +283 -0
  56. openprotein/models/models.py +33 -0
  57. openprotein/predictor/__init__.py +25 -0
  58. openprotein/predictor/api.py +384 -0
  59. openprotein/predictor/models.py +374 -0
  60. openprotein/predictor/prediction.py +79 -0
  61. openprotein/predictor/predictor.py +242 -0
  62. openprotein/predictor/schemas.py +113 -0
  63. openprotein/predictor/validate.py +40 -0
  64. openprotein/prompt/__init__.py +9 -0
  65. openprotein/prompt/api.py +505 -0
  66. openprotein/prompt/models.py +142 -0
  67. openprotein/prompt/prompt.py +130 -0
  68. openprotein/prompt/schemas.py +49 -0
  69. openprotein/protein.py +587 -0
  70. openprotein/svd/__init__.py +9 -0
  71. openprotein/svd/api.py +206 -0
  72. openprotein/svd/models.py +288 -0
  73. openprotein/svd/schemas.py +31 -0
  74. openprotein/svd/svd.py +134 -0
  75. openprotein/umap/__init__.py +9 -0
  76. openprotein/umap/api.py +259 -0
  77. openprotein/umap/models.py +211 -0
  78. openprotein/umap/schemas.py +35 -0
  79. openprotein/umap/umap.py +175 -0
  80. openprotein/utils/uuid.py +29 -0
  81. openprotein_python-0.8.2.dist-info/METADATA +176 -0
  82. openprotein_python-0.8.2.dist-info/RECORD +84 -0
  83. openprotein_python-0.8.2.dist-info/WHEEL +4 -0
  84. openprotein_python-0.8.2.dist-info/licenses/LICENSE.txt +30 -0
@@ -0,0 +1,374 @@
1
+ """Predictor models for making predictions on new sequences."""
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from openprotein.base import APISession
6
+ from openprotein.data import AssayDataset, DataAPI
7
+ from openprotein.embeddings import EmbeddingModel, EmbeddingsAPI
8
+ from openprotein.errors import InvalidParameterError
9
+ from openprotein.jobs import Future, JobsAPI, JobType
10
+ from openprotein.svd import SVDAPI, SVDModel
11
+
12
+ from . import api
13
+ from .prediction import PredictionResultFuture
14
+ from .schemas import (
15
+ PredictorEnsembleJob,
16
+ PredictorMetadata,
17
+ PredictorTrainJob,
18
+ PredictorType,
19
+ )
20
+ from .validate import CVResultFuture
21
+
22
+ if TYPE_CHECKING:
23
+ from openprotein.design import ModelCriterion
24
+
25
+
26
class PredictorModel(Future):
    """
    Class providing predict endpoint for fitted predictor models.

    Also implements a Future that waits for train job.

    The ``<``, ``>`` and ``==`` operators do not compare models. Applied to a
    numeric target they build a design ``ModelCriterion`` for predictors that
    were trained on exactly one property.
    """

    job: PredictorTrainJob | None

    def __init__(
        self,
        session: APISession,
        job: PredictorTrainJob | PredictorEnsembleJob | None = None,
        metadata: PredictorMetadata | None = None,
    ):
        """
        Construct a predictor model.

        Takes in either a train job, or the predictor metadata.

        Parameters
        ----------
        session : APISession
            Session used for all requests made by this model.
        job : PredictorTrainJob or PredictorEnsembleJob or None
            Job to wait on. When omitted it is derived from the metadata.
        metadata : PredictorMetadata or None
            Predictor metadata. When omitted it is fetched using the job id.

        Raises
        ------
        ValueError
            If neither metadata nor a job with a job id is provided.

        :meta private:
        """
        self._training_assay = None

        # initialize the predictor metadata
        if metadata is None:
            if job is None or job.job_id is None:
                raise ValueError("Expected predictor metadata or job")
            metadata = api.predictor_get(session, job.job_id)
        self._metadata = metadata
        if job is None:
            if metadata.model_spec.type != PredictorType.ENSEMBLE:
                # non-ensemble predictors: look the job up by the predictor id
                jobs_api = getattr(session, "jobs", None)
                assert isinstance(jobs_api, JobsAPI)
                job = PredictorTrainJob.create(jobs_api.get_job(job_id=metadata.id))
            else:
                # ensembles: build a job object locally from the metadata
                job = PredictorEnsembleJob(
                    created_date=self._metadata.created_date,
                    status=self._metadata.status,
                    job_type=JobType.predictor_train,
                )
        super().__init__(session, job)

    def __str__(self) -> str:
        return str(self.metadata)

    def __repr__(self) -> str:
        return repr(self.metadata)

    def __or__(self, model: "PredictorModel") -> "PredictorModelGroup":
        """
        Group this predictor with another one.

        Raises
        ------
        ValueError
            If both predictors have a sequence length constraint and the
            constraints differ.
        """
        # A missing constraint on either side is compatible with anything;
        # only two concrete, different lengths are a conflict.
        if (
            self.sequence_length is not None
            and model.sequence_length is not None
            and model.sequence_length != self.sequence_length
        ):
            raise ValueError("Expected sequence lengths to either match or be None.")
        return PredictorModelGroup(
            session=self.session,
            models=[self, model],
            sequence_length=self.sequence_length or model.sequence_length,
            check_sequence_length=False,
        )

    def _make_criterion(self, target: float, direction: str) -> "ModelCriterion":
        """Build a ModelCriterion for ``target`` using the named direction member."""
        from openprotein.design import ModelCriterion

        if len(self.training_properties) == 1:
            return ModelCriterion(
                model_id=self.id,
                measurement_name=self.training_properties[0],
                criterion=ModelCriterion.Criterion(
                    target=target,
                    direction=getattr(
                        ModelCriterion.Criterion.DirectionEnum, direction
                    ),
                ),
            )
        raise self.InvalidMultitaskModelToCriterion()

    def __lt__(self, target: float) -> "ModelCriterion":
        """Create a criterion requiring the predicted property to be below ``target``."""
        return self._make_criterion(target, "lt")

    def __gt__(self, target: float) -> "ModelCriterion":
        """Create a criterion requiring the predicted property to be above ``target``."""
        return self._make_criterion(target, "gt")

    def __eq__(self, target: float) -> "ModelCriterion":
        """
        Create a criterion requiring the predicted property to equal ``target``.

        Note that because ``__eq__`` is overridden without ``__hash__``,
        instances of this class are not hashable.
        """
        return self._make_criterion(target, "eq")

    class InvalidMultitaskModelToCriterion(Exception):
        """
        Exception raised when trying to create model criterion from multitask predictor.

        :meta private:
        """

    @property
    def id(self):
        """ID of predictor."""
        return self._metadata.id

    @property
    def reduction(self):
        """The reduction of the embeddings used to train the predictor, if any."""
        return (
            self._metadata.model_spec.features.reduction
            if self._metadata.model_spec.features is not None
            else None
        )

    @property
    def sequence_length(self):
        """The sequence length constraint on the predictor, if any."""
        if (constraints := self._metadata.model_spec.constraints) is not None:
            return constraints.sequence_length
        return None

    @property
    def training_assay(self) -> AssayDataset:
        """The assay the predictor was trained on (fetched once, then cached)."""
        if self._training_assay is None:
            self._training_assay = self.get_assay()
        return self._training_assay

    @property
    def training_properties(self) -> list[str]:
        """The list of properties the predictor was trained on."""
        return self._metadata.training_dataset.properties

    @property
    def metadata(self):
        """The predictor metadata, refreshed while the predictor is not done."""
        self._refresh_metadata()
        return self._metadata

    def _refresh_metadata(self):
        """Re-fetch the metadata unless it already reports being done."""
        if not self._metadata.is_done():
            self._metadata = api.predictor_get(self.session, self._metadata.id)

    def get_model(self) -> EmbeddingModel | SVDModel | None:
        """
        Retrieve the embeddings or SVD model used to create embeddings to train on.

        Returns
        -------
        EmbeddingModel or SVDModel or None
            The feature model, or None when the predictor has no features
            spec or no feature model id.

        Raises
        ------
        ValueError
            If the feature type is neither "PLM" nor "SVD".
        """
        features = self._metadata.model_spec.features
        if features is None or features.model_id is None:
            return None
        model_id = features.model_id
        feature_type = features.type.upper()
        if feature_type == "PLM":
            return EmbeddingModel.create(session=self.session, model_id=model_id)
        if feature_type == "SVD":
            svd_api = getattr(self.session, "svd", None)
            assert isinstance(svd_api, SVDAPI)
            return svd_api.get_svd(svd_id=model_id)
        raise ValueError(f"Unexpected feature type {features.type}")

    @property
    def model(self) -> EmbeddingModel | SVDModel | None:
        """The embeddings or SVD model used to create embeddings to train on."""
        return self.get_model()

    def delete(self) -> bool:
        """
        Delete this predictor model.
        """
        return api.predictor_delete(self.session, self.id)

    def get(self, verbose: bool = False):
        """
        Returns the train loss curves.

        ``verbose`` is accepted but not used by this implementation.
        """
        return self.metadata.traingraphs

    def get_assay(self) -> AssayDataset:
        """
        Get assay used for train job.

        Returns
        -------
        AssayDataset: Assay dataset used for train job.
        """
        data_api = getattr(self.session, "data", None)
        assert isinstance(data_api, DataAPI)
        return data_api.get(assay_id=self._metadata.training_dataset.assay_id)

    def crossvalidate(self, n_splits: int | None = None) -> CVResultFuture:
        """
        Run a crossvalidation on the trained predictor.

        Parameters
        ----------
        n_splits : int or None
            Number of splits forwarded to the crossvalidation request.
        """
        return CVResultFuture.create(
            session=self.session,
            job=api.predictor_crossvalidate_post(
                session=self.session,
                predictor_id=self.id,
                n_splits=n_splits,
            ),
        )

    def predict(self, sequences: list[bytes] | list[str]) -> PredictionResultFuture:
        """
        Make predictions about the trained properties for a list of sequences.

        Raises
        ------
        InvalidParameterError
            If the predictor has a sequence length constraint and any
            sequence has a different length.
        """
        if self.sequence_length is not None:
            for sequence in sequences:
                # convert to string to check token length
                text = sequence if isinstance(sequence, str) else sequence.decode()
                if len(text) != self.sequence_length:
                    raise InvalidParameterError(
                        f"Expected sequences to predict to be of length {self.sequence_length}"
                    )
        return PredictionResultFuture.create(
            session=self.session,
            job=api.predictor_predict_post(
                session=self.session, predictor_id=self.id, sequences=sequences
            ),
        )

    def single_site(self, sequence: bytes | str) -> PredictionResultFuture:
        """
        Compute the single-site mutated predictions of a base sequence.

        Raises
        ------
        InvalidParameterError
            If the predictor has a sequence length constraint and the base
            sequence has a different length.
        """
        if self.sequence_length is not None:
            # convert to string to check token length
            seq = sequence if isinstance(sequence, str) else sequence.decode()
            if len(seq) != self.sequence_length:
                raise InvalidParameterError(
                    f"Expected sequence to predict to be of length {self.sequence_length}"
                )
        return PredictionResultFuture.create(
            session=self.session,
            job=api.predictor_predict_single_site_post(
                session=self.session, predictor_id=self.id, base_sequence=sequence
            ),
        )
273
+
274
+
275
class PredictorModelGroup(Future):
    """
    Group of predictor models that are queried together.

    Predictions are requested for all member models in a single call using
    the ids of the grouped predictors.
    """

    __models__: list[PredictorModel]

    def __init__(
        self,
        session: APISession,
        models: list[PredictorModel],
        sequence_length: int | None = None,
        check_sequence_length: bool = True,  # turn off checking - prevent n^2 operation when chaining many
    ):
        """
        Construct a group of predictor models.

        Parameters
        ----------
        session : APISession
            Session used for all requests made by this group.
        models : list[PredictorModel]
            Predictors to group. Must not be empty.
        sequence_length : int or None
            Known sequence length constraint of the group, if any.
        check_sequence_length : bool
            Whether to verify that all constrained models agree on the
            sequence length. Default is True.

        Raises
        ------
        ValueError
            If no models are given, or if checking is enabled and two
            sequence length constraints disagree.
        """
        if len(models) == 0:
            raise ValueError("Expected at least one model to group")
        # calculate and check sequence length compatibility
        if check_sequence_length:
            for m in models:
                if m.sequence_length is not None:
                    if sequence_length is None:
                        sequence_length = m.sequence_length
                    elif sequence_length != m.sequence_length:
                        raise ValueError(
                            "Expected sequence lengths of all models to either match or be None."
                        )
        self.sequence_length = sequence_length
        self.session = session
        self.__models__ = models

    def __str__(self) -> str:
        return repr(self.__models__)

    def __repr__(self) -> str:
        return repr(self.__models__)

    def __or__(self, model: PredictorModel) -> "PredictorModelGroup":
        """
        Return a new group with ``model`` appended.

        Raises
        ------
        ValueError
            If both the group and the model have a sequence length
            constraint and the constraints differ.
        """
        # An unconstrained model is compatible with any group, matching the
        # constructor's check which skips models without a constraint.
        if (
            self.sequence_length is not None
            and model.sequence_length is not None
            and model.sequence_length != self.sequence_length
        ):
            raise ValueError("Expected sequence lengths to either match or be None.")
        return PredictorModelGroup(
            session=self.session,
            models=self.__models__ + [model],
            sequence_length=self.sequence_length or model.sequence_length,
            check_sequence_length=False,
        )

    def predict(self, sequences: list[bytes] | list[str]) -> PredictionResultFuture:
        """
        Make predictions about the trained properties for a list of sequences.

        Raises
        ------
        InvalidParameterError
            If the group has a sequence length constraint and any sequence
            has a different length.
        """
        if self.sequence_length is not None:
            for sequence in sequences:
                # convert to string to check token length
                text = sequence if isinstance(sequence, str) else sequence.decode()
                if len(text) != self.sequence_length:
                    raise InvalidParameterError(
                        f"Expected sequences to predict to be of length {self.sequence_length}"
                    )
        return PredictionResultFuture.create(
            session=self.session,
            job=api.predictor_predict_multi_post(
                session=self.session,
                predictor_ids=[m.id for m in self.__models__],
                sequences=sequences,
            ),
        )

    def single_site(self, sequence: bytes | str) -> PredictionResultFuture:
        """
        Compute the single-site mutated predictions of a base sequence.

        Raises
        ------
        InvalidParameterError
            If the group has a sequence length constraint and the base
            sequence has a different length.
        """
        if self.sequence_length is not None:
            # convert to string to check token length
            seq = sequence if isinstance(sequence, str) else sequence.decode()
            if len(seq) != self.sequence_length:
                raise InvalidParameterError(
                    f"Expected sequence to predict to be of length {self.sequence_length}"
                )
        # The group never sets a job or id of its own, so the single-predictor
        # endpoint with `self.id` cannot work here; send all member ids to the
        # multi-predictor single-site endpoint, mirroring `predict`.
        # NOTE(review): endpoint name inferred from `predictor_predict_multi_post`
        # and `PredictMultiSingleSiteJob` - confirm against the api module.
        return PredictionResultFuture.create(
            session=self.session,
            job=api.predictor_predict_multi_single_site_post(
                session=self.session,
                predictor_ids=[m.id for m in self.__models__],
                base_sequence=sequence,
            ),
        )

    def get(self, verbose: bool = False):
        """
        Returns the predictor model.

        :meta private:
        """
        return self

    def delete(self):
        """Delete the predictor identified by ``self.id``."""
        # NOTE(review): this class's constructor never sets a job or an id, so
        # `self.id` presumably fails here; confirm whether deleting a group is
        # meant to be supported and what it should delete.
        return api.predictor_delete(session=self.session, predictor_id=self.id)
@@ -0,0 +1,79 @@
1
+ """Prediction results represented as futures."""
2
+
3
+ import numpy as np
4
+
5
+ from openprotein.base import APISession
6
+ from openprotein.jobs import Future
7
+
8
+ from . import api
9
+ from .schemas import (
10
+ PredictJob,
11
+ PredictMultiJob,
12
+ PredictMultiSingleSiteJob,
13
+ PredictSingleSiteJob,
14
+ )
15
+
16
+
17
class PredictionResultFuture(Future):
    """Future wrapping the results of a prediction job."""

    job: PredictJob | PredictSingleSiteJob | PredictMultiJob | PredictMultiSingleSiteJob

    def __init__(
        self,
        session: APISession,
        job: (
            PredictJob
            | PredictSingleSiteJob
            | PredictMultiJob
            | PredictMultiSingleSiteJob
        ),
        sequences: list[bytes] | None = None,
    ):
        """
        Create the future for a prediction job.

        Args:
            session (APISession): session used to query results
            job: the prediction job to wait on
            sequences (list[bytes] | None): sequences of the job, if already
                known; otherwise they are fetched on first access
        """
        super().__init__(session, job)
        self._sequences = sequences

    @property
    def sequences(self):
        """Sequences of the prediction job, fetched once and then cached."""
        cached = self._sequences
        if cached is not None:
            return cached
        fetched = api.predictor_predict_get_sequences(self.session, self.job.job_id)
        self._sequences = fetched
        return fetched

    @property
    def id(self):
        """ID of the prediction job."""
        return self.job.job_id

    def __keys__(self):
        return self.sequences

    def get_item(self, sequence: bytes) -> tuple[np.ndarray, np.ndarray]:
        """
        Get prediction results for specified sequence.

        Args:
            sequence (bytes): sequence to fetch results for

        Returns:
            mu (np.ndarray): means of sequence prediction
            var (np.ndarray): variances of sequence prediction
        """
        job_id = self.job.job_id
        raw = api.predictor_predict_get_sequence_result(self.session, job_id, sequence)
        return api.decode_predict(raw)

    def get(self, verbose: bool = False) -> tuple[np.ndarray, np.ndarray]:
        """
        Get the batched prediction results of the job.

        Args:
            verbose (bool): accepted but not used by this implementation

        Returns:
            mu (np.ndarray): means of predictions
            var (np.ndarray): variances of predictions
        """
        raw = api.predictor_predict_get_batched_result(self.session, self.job.job_id)
        return api.decode_predict(raw, batched=True)
@@ -0,0 +1,242 @@
1
+ """Predictor API providing the interface to train and predict predictors."""
2
+
3
+ from openprotein.base import APISession
4
+ from openprotein.common import FeatureType, ReductionType
5
+ from openprotein.data import (
6
+ AssayDataset,
7
+ AssayMetadata,
8
+ )
9
+ from openprotein.embeddings import EmbeddingModel, EmbeddingsAPI
10
+ from openprotein.errors import InvalidParameterError
11
+ from openprotein.svd import SVDAPI, SVDModel
12
+
13
+ from . import api
14
+ from .models import PredictorModel
15
+
16
+
17
class PredictorAPI:
    """Predictor API providing the interface to train and predict predictors."""

    def __init__(
        self,
        session: APISession,
    ):
        self.session = session

    def get_predictor(
        self,
        predictor_id: str,
        include_stats: bool = False,
        include_calibration_curves: bool = False,
    ) -> PredictorModel:
        """
        Get predictor by model_id.

        PredictorModel allows all the usual prediction job manipulation:
        e.g. making POST and GET requests for this predictor specifically.

        Parameters
        ----------
        predictor_id : str
            The model identifier.
        include_stats : bool
            Whether to include stats of the predictor from the latest evaluation
            (pearson, spearman, ece). Default is False.
        include_calibration_curves : bool
            Whether to include calibration curves of the predictor from the latest
            evaluation. Default is False.

        Returns
        -------
        PredictorModel
            The predictor model to inspect and make predictions with.

        Raises
        ------
        HTTPError
            If the GET request does not succeed.
        """
        metadata = api.predictor_get(
            session=self.session,
            predictor_id=predictor_id,
            include_stats=include_stats,
            include_calibration_curves=include_calibration_curves,
        )
        return PredictorModel(
            session=self.session,
            metadata=metadata,
        )

    def list_predictors(
        self,
        limit: int = 100,
        offset: int = 0,
        include_stats: bool = False,
        include_calibration_curves: bool = False,
    ) -> list[PredictorModel]:
        """
        List predictors available.

        Parameters
        ----------
        limit : int
            Limit of the number of predictors to return in list. Default is 100.
        offset : int
            Offset to the predictors to query for paged queries. Default is 0.
        include_stats : bool
            Whether to include stats of each predictor from the latest evaluation
            (pearson, spearman, ece). Default is False.
        include_calibration_curves : bool
            Whether to include calibration curves of each predictor from the latest
            evaluation. Default is False.

        Returns
        -------
        list[PredictorModel]
            List of predictor models to inspect and make predictions with.

        Raises
        ------
        HTTPError
            If the GET request does not succeed.
        """
        metadatas = api.predictor_list(
            session=self.session,
            limit=limit,
            offset=offset,
            include_stats=include_stats,
            include_calibration_curves=include_calibration_curves,
        )
        return [
            PredictorModel(
                session=self.session,
                metadata=m,
            )
            for m in metadatas
        ]

    def fit_gp(
        self,
        assay: AssayDataset | AssayMetadata | str,
        properties: list[str],
        model: EmbeddingModel | SVDModel | str,
        feature_type: FeatureType | None = None,
        reduction: ReductionType | None = None,
        name: str | None = None,
        description: str | None = None,
        **kwargs,
    ) -> PredictorModel:
        """
        Fit a GP on an assay with the specified feature model and hyperparameters.

        Parameters
        ----------
        assay : AssayMetadata or AssayDataset or str
            Assay to fit GP on.
        properties : list of str
            Properties in the assay to fit the gp on.
        model : EmbeddingModel or SVDModel or str
            Instance of either EmbeddingModel or SVDModel to use depending
            on feature type. Can also be a str specifying the model id,
            but then feature_type would have to be specified.
        feature_type : FeatureType or None
            Type of features to use for encoding sequences. "SVD" or "PLM".
            None would require model to be EmbeddingModel or SVDModel.
            Ignored when model is an EmbeddingModel or SVDModel instance,
            since the type is then taken from the instance.
        reduction : str or None, optional
            Type of embedding reduction to use for computing features.
            E.g. "MEAN" or "SUM". Used only if using EmbeddingModel, and
            must be non-nil if using an EmbeddingModel. Defaults to None.
        name : str or None, optional
            Name forwarded to the fit request. Defaults to None.
        description : str or None, optional
            Description forwarded to the fit request. Defaults to None.
        kwargs :
            Additional keyword arguments to be passed to foundational models, e.g. prompt_id for PoET models.

        Returns
        -------
        PredictorModel
            The GP model being fit.

        Raises
        ------
        InvalidParameterError
            If feature_type cannot be determined, if it is neither PLM nor
            SVD, or if PLM features are used without a reduction.
        """
        # extract feature type: a model instance always determines it
        if isinstance(model, EmbeddingModel):
            feature_type = FeatureType.PLM
        elif isinstance(model, SVDModel):
            feature_type = FeatureType.SVD
        if feature_type is None:
            raise InvalidParameterError(
                "Expected feature_type to be provided if passing str model_id as model"
            )
        # get model if model_id
        if feature_type == FeatureType.PLM:
            if reduction is None:
                raise InvalidParameterError(
                    "Expected reduction if using EmbeddingModel"
                )
            if isinstance(model, str):
                embeddings_api = getattr(self.session, "embedding", None)
                assert isinstance(embeddings_api, EmbeddingsAPI)
                model = embeddings_api.get_model(model)
            assert isinstance(model, EmbeddingModel), "Expected EmbeddingModel"
            model_id = model.id
        elif feature_type == FeatureType.SVD:
            if isinstance(model, str):
                svd_api = getattr(self.session, "svd", None)
                assert isinstance(svd_api, SVDAPI)
                model = svd_api.get_svd(model)
            assert isinstance(model, SVDModel), "Expected SVDModel"
            model_id = model.id
        else:
            # Without this branch model_id would be unbound below and the
            # call would fail with an unrelated UnboundLocalError.
            raise InvalidParameterError(f"Unexpected feature type {feature_type}")
        # get assay_id
        if isinstance(assay, AssayMetadata):
            assay_id = assay.assay_id
        elif isinstance(assay, AssayDataset):
            assay_id = assay.id
        else:
            assay_id = assay
        return PredictorModel(
            session=self.session,
            job=api.predictor_fit_gp_post(
                session=self.session,
                assay_id=assay_id,
                properties=properties,
                feature_type=feature_type,
                model_id=model_id,
                reduction=reduction,
                name=name,
                description=description,
                **kwargs,
            ),
        )

    def delete_predictor(self, predictor_id: str) -> bool:
        """
        Delete predictor model.

        Parameters
        ----------
        predictor_id : str
            The ID of the predictor.

        Returns
        -------
        bool
            True: successful deletion

        """
        return api.predictor_delete(session=self.session, predictor_id=predictor_id)

    def ensemble(self, predictors: list[PredictorModel]) -> PredictorModel:
        """
        Ensemble predictor models together.

        Parameters
        ----------
        predictors : list[PredictorModel]
            List of predictors to ensemble together.

        Returns
        -------
        PredictorModel
            Ensembled predictor model
        """
        return PredictorModel(
            session=self.session,
            metadata=api.predictor_ensemble(
                session=self.session,
                predictor_ids=[predictor.id for predictor in predictors],
            ),
        )