openprotein-python 0.8.2__1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. openprotein/__init__.py +164 -0
  2. openprotein/_version.py +48 -0
  3. openprotein/align/__init__.py +8 -0
  4. openprotein/align/align.py +395 -0
  5. openprotein/align/api.py +428 -0
  6. openprotein/align/future.py +55 -0
  7. openprotein/align/msa.py +129 -0
  8. openprotein/align/schemas.py +165 -0
  9. openprotein/base.py +181 -0
  10. openprotein/chains.py +88 -0
  11. openprotein/common/__init__.py +5 -0
  12. openprotein/common/features.py +7 -0
  13. openprotein/common/model_metadata.py +33 -0
  14. openprotein/common/reduction.py +8 -0
  15. openprotein/config.py +9 -0
  16. openprotein/csv.py +31 -0
  17. openprotein/data/__init__.py +9 -0
  18. openprotein/data/api.py +218 -0
  19. openprotein/data/assaydataset.py +178 -0
  20. openprotein/data/data.py +93 -0
  21. openprotein/data/schemas.py +27 -0
  22. openprotein/design/__init__.py +16 -0
  23. openprotein/design/api.py +259 -0
  24. openprotein/design/design.py +125 -0
  25. openprotein/design/future.py +146 -0
  26. openprotein/design/schemas.py +607 -0
  27. openprotein/embeddings/__init__.py +27 -0
  28. openprotein/embeddings/api.py +619 -0
  29. openprotein/embeddings/embeddings.py +151 -0
  30. openprotein/embeddings/esm.py +33 -0
  31. openprotein/embeddings/future.py +146 -0
  32. openprotein/embeddings/models.py +421 -0
  33. openprotein/embeddings/openprotein.py +21 -0
  34. openprotein/embeddings/poet.py +446 -0
  35. openprotein/embeddings/poet2.py +505 -0
  36. openprotein/embeddings/schemas.py +78 -0
  37. openprotein/errors.py +76 -0
  38. openprotein/fasta.py +92 -0
  39. openprotein/fold/__init__.py +21 -0
  40. openprotein/fold/alphafold2.py +131 -0
  41. openprotein/fold/api.py +287 -0
  42. openprotein/fold/boltz.py +691 -0
  43. openprotein/fold/esmfold.py +54 -0
  44. openprotein/fold/fold.py +107 -0
  45. openprotein/fold/future.py +509 -0
  46. openprotein/fold/models.py +139 -0
  47. openprotein/fold/schemas.py +39 -0
  48. openprotein/jobs/__init__.py +9 -0
  49. openprotein/jobs/api.py +71 -0
  50. openprotein/jobs/futures.py +746 -0
  51. openprotein/jobs/jobs.py +69 -0
  52. openprotein/jobs/schemas.py +135 -0
  53. openprotein/models/__init__.py +4 -0
  54. openprotein/models/base.py +63 -0
  55. openprotein/models/foundation/rfdiffusion.py +283 -0
  56. openprotein/models/models.py +33 -0
  57. openprotein/predictor/__init__.py +25 -0
  58. openprotein/predictor/api.py +384 -0
  59. openprotein/predictor/models.py +374 -0
  60. openprotein/predictor/prediction.py +79 -0
  61. openprotein/predictor/predictor.py +242 -0
  62. openprotein/predictor/schemas.py +113 -0
  63. openprotein/predictor/validate.py +40 -0
  64. openprotein/prompt/__init__.py +9 -0
  65. openprotein/prompt/api.py +505 -0
  66. openprotein/prompt/models.py +142 -0
  67. openprotein/prompt/prompt.py +130 -0
  68. openprotein/prompt/schemas.py +49 -0
  69. openprotein/protein.py +587 -0
  70. openprotein/svd/__init__.py +9 -0
  71. openprotein/svd/api.py +206 -0
  72. openprotein/svd/models.py +288 -0
  73. openprotein/svd/schemas.py +31 -0
  74. openprotein/svd/svd.py +134 -0
  75. openprotein/umap/__init__.py +9 -0
  76. openprotein/umap/api.py +259 -0
  77. openprotein/umap/models.py +211 -0
  78. openprotein/umap/schemas.py +35 -0
  79. openprotein/umap/umap.py +175 -0
  80. openprotein/utils/uuid.py +29 -0
  81. openprotein_python-0.8.2.dist-info/METADATA +176 -0
  82. openprotein_python-0.8.2.dist-info/RECORD +84 -0
  83. openprotein_python-0.8.2.dist-info/WHEEL +4 -0
  84. openprotein_python-0.8.2.dist-info/licenses/LICENSE.txt +30 -0
@@ -0,0 +1,384 @@
1
+ """Predictor REST API for making HTTP calls to our predictor backend."""
2
+
3
+ import io
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ from pydantic import TypeAdapter
8
+
9
+ from openprotein.base import APISession
10
+ from openprotein.errors import APIError
11
+ from openprotein.jobs import Job
12
+
13
+ from .schemas import (
14
+ Job,
15
+ PredictJob,
16
+ PredictMultiJob,
17
+ PredictMultiSingleSiteJob,
18
+ PredictorCVJob,
19
+ PredictorMetadata,
20
+ PredictorTrainJob,
21
+ PredictSingleSiteJob,
22
+ )
23
+
24
# Base path of the predictor REST API; every endpoint URL in this module is built from it.
PATH_PREFIX = "v1/predictor"
25
+
26
+
27
def predictor_list(
    session: APISession,
    limit: int = 100,
    offset: int = 0,
    include_stats: bool = False,
    include_calibration_curves: bool = False,
) -> list[PredictorMetadata]:
    """
    List trained predictors, one page at a time.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    limit : int
        Maximum number of predictors to return. Defaults to 100.
    offset : int
        Number of predictors to skip, for paged queries. Defaults to 0.
    include_stats : bool
        Whether to include stats from each predictor's latest evaluation
        (e.g. pearson, spearman, ece). Sent as the ``stats`` query parameter.
    include_calibration_curves : bool
        Whether to include calibration curves from the latest evaluation.
        Sent as the ``curve`` query parameter.

    Returns
    -------
    list[PredictorMetadata]
        Validated metadata for each predictor in the requested page.
    """
    query = dict(
        limit=limit,
        offset=offset,
        stats=include_stats,
        curve=include_calibration_curves,
    )
    payload = session.get(PATH_PREFIX, params=query).json()
    adapter = TypeAdapter(list[PredictorMetadata])
    return adapter.validate_python(payload)
66
+
67
+
68
def predictor_get(
    session: APISession,
    predictor_id: str,
    include_stats: bool = False,
    include_calibration_curves: bool = False,
) -> PredictorMetadata:
    """
    Fetch a single trained predictor by its identifier.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    predictor_id : str
        Unique identifier of the predictor.
    include_stats : bool
        Whether to include stats from the predictor's latest evaluation
        (pearson, spearman, ece). Sent as the ``stats`` query parameter.
    include_calibration_curves : bool
        Whether to include calibration curves from the latest evaluation.
        Sent as the ``curve`` query parameter.

    Returns
    -------
    PredictorMetadata
        Validated metadata of the requested predictor.
    """
    url = f"{PATH_PREFIX}/{predictor_id}"
    query = dict(stats=include_stats, curve=include_calibration_curves)
    reply = session.get(url, params=query)
    return TypeAdapter(PredictorMetadata).validate_python(reply.json())
101
+
102
+
103
def predictor_fit_gp_post(
    session: APISession,
    assay_id: str,
    properties: list[str],
    feature_type: str,
    model_id: str,
    reduction: str | None = None,
    name: str | None = None,
    description: str | None = None,
    **kwargs,
) -> PredictorTrainJob:
    """
    Create a job that trains a Gaussian process (GP) predictor on assay data.

    The request asks for an RBF kernel.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    assay_id : str
        Assay ID to fit the GP on.
    properties : list[str]
        Properties (measurement columns) in the assay to fit the GP on.
    feature_type : str
        Type of features to use for encoding sequences, e.g. PLM or SVD.
    model_id : str
        Protembed/SVD model to use, depending on ``feature_type``.
    reduction : str | None
        Type of embedding reduction to use for computing features. Omitted
        from the request when None.
    name : str | None
        Optional name of the predictor model. Omitted from the request when
        None (the docs say the server then generates one).
    description : str | None
        Optional description to attach to the model. Omitted when None.
    **kwargs
        Additional fields merged into the top level of the request body, e.g.
        ``prompt_id`` for PoET models. Keys that collide with fields built
        here (``dataset``, ``features``, ``kernel``, ``name``,
        ``description``) overwrite them.

    Returns
    -------
    PredictorTrainJob
        Train job that can be tracked for progress.
    """
    endpoint = PATH_PREFIX + "/gp"

    features: dict = {
        "type": feature_type,
        "model_id": model_id,
    }
    if reduction is not None:
        features["reduction"] = reduction

    body: dict = {
        "dataset": {
            "assay_id": assay_id,
            "properties": properties,
        },
        "features": features,
        "kernel": {"type": "rbf"},
    }
    if name is not None:
        body["name"] = name
    if description is not None:
        body["description"] = description
    # Extra options (e.g. embedding kwargs such as prompt_id) go in last so
    # callers can pass through fields this function does not know about.
    body.update(kwargs)

    response = session.post(endpoint, json=body)
    return PredictorTrainJob.model_validate(response.json())
170
+
171
+
172
def predictor_ensemble(
    session: APISession, predictor_ids: list[str]
) -> PredictorMetadata:
    """
    Request an ensemble predictor built from existing predictors.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    predictor_ids : list[str]
        IDs of the predictors to combine; sent as ``model_ids``.

    Returns
    -------
    PredictorMetadata
        Metadata of the resulting ensemble predictor, as returned by the server.
    """
    endpoint = PATH_PREFIX + "/ensemble"
    body = {"model_ids": predictor_ids}
    response = session.post(endpoint, json=body)
    return PredictorMetadata.model_validate(response.json())
183
+
184
+
185
def predictor_delete(session: APISession, predictor_id: str):
    """
    Delete a trained predictor.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    predictor_id : str
        Unique identifier of the predictor to delete.

    Returns
    -------
    bool
        ``True`` when the server answers with a 2xx status code.

    Raises
    ------
    APIError
        If the status code is outside the 2xx range; the message is the
        response body text.
    """
    response = session.delete(f"{PATH_PREFIX}/{predictor_id}")
    if not 200 <= response.status_code < 300:
        raise APIError(response.text)
    return True
192
+
193
+
194
def predictor_crossvalidate_post(
    session: APISession, predictor_id: str, n_splits: int | None = None
):
    """
    Start a cross-validation job for a trained predictor.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    predictor_id : str
        Unique identifier of the predictor to cross-validate.
    n_splits : int | None
        Number of cross-validation splits, sent as a query parameter. Omitted
        from the request when None (presumably leaving the choice to the
        server).

    Returns
    -------
    PredictorCVJob
        The submitted cross-validation job.
    """
    query = {} if n_splits is None else {"n_splits": n_splits}
    reply = session.post(f"{PATH_PREFIX}/{predictor_id}/crossvalidate", params=query)
    return PredictorCVJob.model_validate(reply.json())
205
+
206
+
207
def predictor_crossvalidate_get(session: APISession, crossvalidate_job_id: str):
    """
    Fetch the raw results of a cross-validation job.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    crossvalidate_job_id : str
        ID of the cross-validation job.

    Returns
    -------
    bytes
        Raw response body. Presumably the CSV parsed by
        ``decode_crossvalidate`` — confirm against the server.
    """
    return session.get(f"{PATH_PREFIX}/crossvalidate/{crossvalidate_job_id}").content
212
+
213
+
214
def predictor_predict_post(
    session: APISession, predictor_id: str, sequences: list[bytes] | list[str]
):
    """
    Submit a prediction job that scores sequences with one predictor.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    predictor_id : str
        Unique identifier of the predictor to use.
    sequences : list[bytes] | list[str]
        Sequences to score. Bytes entries are decoded to ``str`` before
        being sent as JSON.

    Returns
    -------
    PredictJob
        The submitted prediction job.
    """
    as_text = [seq if isinstance(seq, str) else seq.decode() for seq in sequences]
    reply = session.post(
        f"{PATH_PREFIX}/{predictor_id}/predict",
        json={"sequences": as_text},
    )
    return PredictJob.model_validate(reply.json())
226
+
227
+
228
def predictor_predict_multi_post(
    session: APISession, predictor_ids: list[str], sequences: list[bytes] | list[str]
):
    """
    Submit a prediction job that scores sequences with several predictors.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    predictor_ids : list[str]
        IDs of the predictors to use; sent as ``model_ids``.
    sequences : list[bytes] | list[str]
        Sequences to score. Bytes entries are decoded to ``str`` before
        being sent as JSON.

    Returns
    -------
    PredictMultiJob
        The submitted multi-predictor prediction job.
    """
    endpoint = PATH_PREFIX + "/predict"

    sequences_unicode = [(s if isinstance(s, str) else s.decode()) for s in sequences]
    body = {
        "model_ids": predictor_ids,
        "sequences": sequences_unicode,
    }
    response = session.post(endpoint, json=body)

    return PredictMultiJob.model_validate(response.json())
241
+
242
+
243
def predictor_predict_single_site_post(
    session: APISession,
    predictor_id: str,
    base_sequence: bytes | str,
):
    """
    Submit a single-site prediction job for a base sequence with one predictor.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    predictor_id : str
        Unique identifier of the predictor to use.
    base_sequence : bytes | str
        Sequence whose single-site variants are scored; bytes are decoded to
        ``str`` before being sent as JSON.

    Returns
    -------
    PredictSingleSiteJob
        The submitted single-site prediction job.
    """
    if isinstance(base_sequence, bytes):
        base_sequence = base_sequence.decode()
    reply = session.post(
        f"{PATH_PREFIX}/{predictor_id}/predict_single_site",
        json={"base_sequence": base_sequence},
    )
    return PredictSingleSiteJob.model_validate(reply.json())
259
+
260
+
261
def predictor_predict_multi_single_site_post(
    session: APISession,
    predictor_ids: list[str],
    base_sequence: bytes | str,
):
    """
    Submit a single-site prediction job for a base sequence with several predictors.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    predictor_ids : list[str]
        IDs of the predictors to use; sent as ``model_ids``.
    base_sequence : bytes | str
        Sequence whose single-site variants are scored; bytes are decoded to
        ``str`` before being sent as JSON.

    Returns
    -------
    PredictMultiSingleSiteJob
        The submitted multi-predictor single-site prediction job.
    """
    endpoint = PATH_PREFIX + "/predict_single_site"

    base_sequence = (
        base_sequence.decode() if isinstance(base_sequence, bytes) else base_sequence
    )
    body = {
        "model_ids": predictor_ids,
        "base_sequence": base_sequence,
    }
    response = session.post(endpoint, json=body)

    return PredictMultiSingleSiteJob.model_validate(response.json())
278
+
279
+
280
def predictor_predict_get_sequences(
    session: APISession, prediction_job_id: str
) -> list[bytes]:
    """
    List the sequences associated with a prediction job.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    prediction_job_id : str
        ID of the prediction job.

    Returns
    -------
    list[bytes]
        The JSON list returned by the server, validated as ``list[bytes]``.
    """
    reply = session.get(f"{PATH_PREFIX}/predict/{prediction_job_id}/sequences")
    as_bytes = TypeAdapter(list[bytes])
    return as_bytes.validate_python(reply.json())
287
+
288
+
289
def predictor_predict_get_sequence_result(
    session: APISession, prediction_job_id: str, sequence: bytes | str
) -> bytes:
    """
    Get the encoded prediction result for one sequence of a prediction job.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    prediction_job_id : str
        ID of the prediction job to retrieve results from.
    sequence : bytes | str
        Sequence to retrieve predictions for. Bytes are decoded (UTF-8 by
        default) before being placed in the request path.

    Returns
    -------
    bytes
        Raw response body. Presumably the unbatched CSV parsed by
        ``decode_predict(..., batched=False)`` — confirm against the server.
    """
    if isinstance(sequence, bytes):
        sequence = sequence.decode()
    # NOTE(review): the sequence is interpolated into the URL path without
    # escaping; this assumes it only contains URL-safe characters
    # (e.g. amino-acid letters).
    endpoint = PATH_PREFIX + f"/predict/{prediction_job_id}/{sequence}"
    response = session.get(endpoint)
    return response.content
313
+
314
+
315
def predictor_predict_get_batched_result(
    session: APISession, prediction_job_id: str
) -> bytes:
    """
    Get the encoded prediction results for all sequences of a prediction job.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    prediction_job_id : str
        ID of the prediction job to retrieve results from.

    Returns
    -------
    bytes
        Raw response body. Presumably the batched CSV parsed by
        ``decode_predict(..., batched=True)`` — confirm against the server.
    """
    endpoint = PATH_PREFIX + f"/predict/{prediction_job_id}"
    response = session.get(endpoint)
    return response.content
337
+
338
+
339
def decode_predict(data: bytes, batched: bool = False) -> tuple[np.ndarray, np.ndarray]:
    """
    Decode prediction scores from CSV bytes received over the API.

    Score columns alternate mean, variance, mean, variance, ...

    Parameters
    ----------
    data : bytes
        Raw CSV bytes received over the API.
    batched : bool
        Whether the result came from a batched request. Batched CSVs are read
        with a header row, and their first (sequence) column is dropped.
        Unbatched CSVs are read without a header. Default False.

    Returns
    -------
    mus : np.ndarray
        2-D array of predicted means (the even-indexed score columns).
    variances : np.ndarray
        2-D array of predicted variances (the odd-indexed score columns).
    """
    buffer = io.BytesIO(data)
    if batched:
        # Header row plus a leading sequence column, which is not a score.
        scores = pd.read_csv(buffer).iloc[:, 1:].values
    else:
        # Expected: a single headerless row of 2n interleaved values.
        scores = pd.read_csv(buffer, header=None).values
    mus = scores[:, ::2]
    variances = scores[:, 1::2]
    return mus, variances
363
+
364
+
365
def decode_crossvalidate(data: bytes) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Decode cross-validation scores from CSV bytes received over the API.

    The CSV is expected to have a header row and the columns
    ``row_num, seq, measurement_name, y, y_mu, y_var`` in that order. Columns
    are selected by position, not by name.

    Parameters
    ----------
    data : bytes
        Raw CSV bytes received over the API.

    Returns
    -------
    y : np.ndarray
        Observed measurement values (4th column).
    mus : np.ndarray
        Predicted means (5th column).
    variances : np.ndarray
        Predicted variances (6th column).
    """
    df = pd.read_csv(io.BytesIO(data))
    # Pull each column on its own so it keeps its parsed numeric dtype.
    # Converting the whole frame with ``df.values`` first would upcast
    # everything to ``object`` because of the string columns.
    y = df.iloc[:, 3].to_numpy()
    mus = df.iloc[:, 4].to_numpy()
    variances = df.iloc[:, 5].to_numpy()
    return y, mus, variances