openprotein-python 0.8.2__1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openprotein/__init__.py +164 -0
- openprotein/_version.py +48 -0
- openprotein/align/__init__.py +8 -0
- openprotein/align/align.py +395 -0
- openprotein/align/api.py +428 -0
- openprotein/align/future.py +55 -0
- openprotein/align/msa.py +129 -0
- openprotein/align/schemas.py +165 -0
- openprotein/base.py +181 -0
- openprotein/chains.py +88 -0
- openprotein/common/__init__.py +5 -0
- openprotein/common/features.py +7 -0
- openprotein/common/model_metadata.py +33 -0
- openprotein/common/reduction.py +8 -0
- openprotein/config.py +9 -0
- openprotein/csv.py +31 -0
- openprotein/data/__init__.py +9 -0
- openprotein/data/api.py +218 -0
- openprotein/data/assaydataset.py +178 -0
- openprotein/data/data.py +93 -0
- openprotein/data/schemas.py +27 -0
- openprotein/design/__init__.py +16 -0
- openprotein/design/api.py +259 -0
- openprotein/design/design.py +125 -0
- openprotein/design/future.py +146 -0
- openprotein/design/schemas.py +607 -0
- openprotein/embeddings/__init__.py +27 -0
- openprotein/embeddings/api.py +619 -0
- openprotein/embeddings/embeddings.py +151 -0
- openprotein/embeddings/esm.py +33 -0
- openprotein/embeddings/future.py +146 -0
- openprotein/embeddings/models.py +421 -0
- openprotein/embeddings/openprotein.py +21 -0
- openprotein/embeddings/poet.py +446 -0
- openprotein/embeddings/poet2.py +505 -0
- openprotein/embeddings/schemas.py +78 -0
- openprotein/errors.py +76 -0
- openprotein/fasta.py +92 -0
- openprotein/fold/__init__.py +21 -0
- openprotein/fold/alphafold2.py +131 -0
- openprotein/fold/api.py +287 -0
- openprotein/fold/boltz.py +691 -0
- openprotein/fold/esmfold.py +54 -0
- openprotein/fold/fold.py +107 -0
- openprotein/fold/future.py +509 -0
- openprotein/fold/models.py +139 -0
- openprotein/fold/schemas.py +39 -0
- openprotein/jobs/__init__.py +9 -0
- openprotein/jobs/api.py +71 -0
- openprotein/jobs/futures.py +746 -0
- openprotein/jobs/jobs.py +69 -0
- openprotein/jobs/schemas.py +135 -0
- openprotein/models/__init__.py +4 -0
- openprotein/models/base.py +63 -0
- openprotein/models/foundation/rfdiffusion.py +283 -0
- openprotein/models/models.py +33 -0
- openprotein/predictor/__init__.py +25 -0
- openprotein/predictor/api.py +384 -0
- openprotein/predictor/models.py +374 -0
- openprotein/predictor/prediction.py +79 -0
- openprotein/predictor/predictor.py +242 -0
- openprotein/predictor/schemas.py +113 -0
- openprotein/predictor/validate.py +40 -0
- openprotein/prompt/__init__.py +9 -0
- openprotein/prompt/api.py +505 -0
- openprotein/prompt/models.py +142 -0
- openprotein/prompt/prompt.py +130 -0
- openprotein/prompt/schemas.py +49 -0
- openprotein/protein.py +587 -0
- openprotein/svd/__init__.py +9 -0
- openprotein/svd/api.py +206 -0
- openprotein/svd/models.py +288 -0
- openprotein/svd/schemas.py +31 -0
- openprotein/svd/svd.py +134 -0
- openprotein/umap/__init__.py +9 -0
- openprotein/umap/api.py +259 -0
- openprotein/umap/models.py +211 -0
- openprotein/umap/schemas.py +35 -0
- openprotein/umap/umap.py +175 -0
- openprotein/utils/uuid.py +29 -0
- openprotein_python-0.8.2.dist-info/METADATA +176 -0
- openprotein_python-0.8.2.dist-info/RECORD +84 -0
- openprotein_python-0.8.2.dist-info/WHEEL +4 -0
- openprotein_python-0.8.2.dist-info/licenses/LICENSE.txt +30 -0
openprotein/svd/api.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
"""SVD REST API for making HTTP calls to our SVD backend."""
|
|
2
|
+
|
|
3
|
+
import io
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from pydantic import TypeAdapter
|
|
7
|
+
|
|
8
|
+
from openprotein.base import APISession
|
|
9
|
+
from openprotein.errors import APIError, InvalidParameterError
|
|
10
|
+
|
|
11
|
+
from .schemas import SVDEmbeddingsJob, SVDFitJob, SVDMetadata
|
|
12
|
+
|
|
13
|
+
# Route prefix shared by every SVD endpoint in this module; individual calls
# append e.g. "/{svd_id}", "/{svd_id}/sequences", "/{svd_id}/embed" or
# "/embed/{job_id}/{sequence}".
PATH_PREFIX = "v1/embeddings/svd"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def svd_list_get(session: APISession) -> list[SVDMetadata]:
    """
    List metadata for every SVD available to the session.

    Parameters
    ----------
    session : APISession
        Session object for API communication.

    Returns
    -------
    list of SVDMetadata
        Metadata for each SVD, including SVD dimension and sequence length.
    """
    payload = session.get(PATH_PREFIX).json()
    adapter = TypeAdapter(list[SVDMetadata])
    return adapter.validate_python(payload)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def svd_get(session: APISession, svd_id: str) -> SVDMetadata:
    """
    Fetch the metadata of a single SVD.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    svd_id : str
        ID of the SVD to look up.

    Returns
    -------
    SVDMetadata
        Metadata of the SVD, including SVD dimension and sequence length.
    """
    reply = session.get(f"{PATH_PREFIX}/{svd_id}")
    payload = reply.json()
    return SVDMetadata.model_validate(payload)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def svd_get_sequences(session: APISession, svd_id: str) -> list[bytes]:
    """
    Fetch the sequences that were used to fit an SVD.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    svd_id : str
        ID of the SVD whose fitting sequences should be returned.

    Returns
    -------
    list of bytes
        The fitting sequences, validated into ``bytes`` by pydantic.
    """
    raw_sequences = session.get(f"{PATH_PREFIX}/{svd_id}/sequences").json()
    return TypeAdapter(list[bytes]).validate_python(raw_sequences)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def embed_get_sequence_result(
    session: APISession, job_id: str, sequence: str | bytes
) -> bytes:
    """
    Download the encoded SVD embedding of one sequence from an embed job.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    job_id : str
        ID of the SVD embeddings job holding the result.
    sequence : str or bytes
        Sequence whose result should be fetched; bytes are decoded to str
        before being placed in the request path.

    Returns
    -------
    bytes
        Raw response body; decode it with ``embed_decode``.
    """
    seq_text = sequence.decode() if isinstance(sequence, bytes) else sequence
    reply = session.get(f"{PATH_PREFIX}/embed/{job_id}/{seq_text}")
    return reply.content
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def embed_decode(data: bytes) -> np.ndarray:
    """
    Decode an embedding returned by the API into a numpy array.

    Parameters
    ----------
    data : bytes
        Raw ``.npy``-formatted bytes received over the API.

    Returns
    -------
    np.ndarray
        The decoded array. Pickled (object) payloads are rejected because
        pickle loading is disabled.
    """
    return np.load(io.BytesIO(data), allow_pickle=False)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def svd_delete(session: APISession, svd_id: str) -> bool:
    """
    Delete an SVD model.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    svd_id : str
        ID of the SVD model to delete.

    Returns
    -------
    bool
        ``True`` when the server answers with a 2xx status.

    Raises
    ------
    APIError
        If the server answers with a non-2xx status; the message is the
        response body text.
    """
    endpoint = PATH_PREFIX + f"/{svd_id}"
    response = session.delete(endpoint)
    if not 200 <= response.status_code < 300:
        raise APIError(response.text)
    return True
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def svd_fit_post(
    session: APISession,
    model_id: str,
    sequences: list[bytes] | list[str] | None = None,
    assay_id: str | None = None,
    n_components: int = 1024,
    reduction: str | None = None,
    **kwargs,
) -> SVDFitJob:
    """
    Create SVD fit job.

    Exactly one of ``sequences`` or ``assay_id`` must be provided.

    Parameters
    ----------
    session: APISession
        Session object for API communication.
    model_id: str
        ID of embeddings model to use.
    sequences: list of bytes or str, or None, optional
        Sequences to fit the SVD with. Bytes are decoded to str before
        being sent. Mutually exclusive with ``assay_id``.
    assay_id: str | None, optional
        ID of an assay containing sequences to fit the SVD with.
        Mutually exclusive with ``sequences``.
    n_components: int
        Number of SVD components to fit. Defaults to 1024.
    reduction: str | None
        Type of embedding reduction to use for computing features.
        E.g. "MEAN" or "SUM". Useful when dealing with variable length
        sequence. Only sent when not None. Defaults to None.
    kwargs:
        Additional keyword arguments to be passed to foundational models,
        e.g. prompt_id for PoET models. They are merged into the request
        body and override any same-named field set above.

    Returns
    -------
    SVDFitJob
        The created fit job.

    Raises
    ------
    InvalidParameterError
        If both ``sequences`` and ``assay_id`` are given, or neither is.
    """
    endpoint = PATH_PREFIX

    body: dict = {
        "model_id": model_id,
        "n_components": n_components,
    }
    if reduction is not None:
        body["reduction"] = reduction
    if sequences is not None:
        # both provided
        if assay_id is not None:
            raise InvalidParameterError("Expected only either sequences or assay_id")
        body["sequences"] = [
            (s if isinstance(s, str) else s.decode()) for s in sequences
        ]
    else:
        # both are none
        if assay_id is None:
            raise InvalidParameterError("Expected either sequences or assay_id")
        body["assay_id"] = assay_id
    # embeddings kwargs (e.g. prompt_id) travel as top-level body fields
    body.update(kwargs)

    response = session.post(endpoint, json=body)
    # return job for metadata
    return SVDFitJob.model_validate(response.json())
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def svd_embed_post(
    session: APISession,
    svd_id: str,
    sequences: list[bytes] | list[str],
    **kwargs,
) -> SVDEmbeddingsJob:
    """
    POST a request for embeddings from the given SVD model.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    svd_id : str
        SVD model to use.
    sequences : list of bytes or str
        Sequences to embed; bytes are decoded to str before being sent.
    kwargs :
        Additional fields merged into the request body, mirroring
        ``svd_fit_post``. ``SVDModel.embed`` forwards its keyword arguments
        here; previously they triggered a TypeError because this function
        accepted none.

    Returns
    -------
    SVDEmbeddingsJob
        The created embeddings job.
    """
    endpoint = PATH_PREFIX + f"/{svd_id}/embed"

    sequences_unicode = [(s if isinstance(s, str) else s.decode()) for s in sequences]
    body: dict = {
        "sequences": sequences_unicode,
    }
    body.update(kwargs)
    response = session.post(endpoint, json=body)

    return SVDEmbeddingsJob.model_validate(response.json())
|
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
"""SVD model representations which can be used for creating reduced embeddings."""
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from openprotein import config
|
|
8
|
+
from openprotein.base import APISession
|
|
9
|
+
from openprotein.common import FeatureType
|
|
10
|
+
from openprotein.data import AssayDataset, AssayMetadata, DataAPI
|
|
11
|
+
from openprotein.embeddings import EmbeddingModel, EmbeddingsResultFuture
|
|
12
|
+
from openprotein.errors import InvalidParameterError
|
|
13
|
+
from openprotein.jobs import Future, JobsAPI
|
|
14
|
+
|
|
15
|
+
from . import api
|
|
16
|
+
from .schemas import SVDEmbeddingsJob, SVDFitJob, SVDMetadata
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from openprotein.predictor import PredictorModel
|
|
20
|
+
from openprotein.umap import UMAPModel
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SVDModel(Future):
    """
    SVD model that can be used to create reduced embeddings.

    The model is also implemented as a `Future` to allow waiting for a fit job.
    """

    job: SVDFitJob

    def __init__(
        self,
        session: APISession,
        job: SVDFitJob | None = None,
        metadata: SVDMetadata | None = None,
    ):
        """
        Construct the SVD model using either job get or svd metadata get.

        Parameters
        ----------
        session : APISession
            Session object for API communication.
        job : SVDFitJob or None, optional
            Fit job of the SVD. Used to fetch the metadata when ``metadata``
            is not provided.
        metadata : SVDMetadata or None, optional
            Metadata of the SVD. Used to fetch the fit job when ``job`` is
            not provided.

        Raises
        ------
        ValueError
            If neither ``job`` nor ``metadata`` is provided.
        """
        # initialize the metadata
        if metadata is None:
            # use job to fetch metadata
            if job is None:
                raise ValueError("Expected svd metadata or job")
            metadata = api.svd_get(session=session, svd_id=job.job_id)
        self._metadata = metadata
        if job is None:
            # the SVD id is also the id of its fit job
            jobs_api = getattr(session, "jobs", None)
            assert isinstance(jobs_api, JobsAPI)
            job = SVDFitJob.create(jobs_api.get_job(job_id=metadata.id))
        super().__init__(session=session, job=job)

    def __str__(self) -> str:
        return str(self.metadata)

    def __repr__(self) -> str:
        return repr(self.metadata)

    @property
    def id(self) -> str:
        """ID of the SVD."""
        return self._metadata.id

    @property
    def n_components(self) -> int:
        """Number of components of the SVD."""
        return self._metadata.n_components

    @property
    def sequence_length(self) -> int | None:
        """Sequence length constraint of the SVD."""
        return self._metadata.sequence_length

    @property
    def reduction(self) -> str | None:
        """Reduction of embeddings used to fit the SVD."""
        return self._metadata.reduction

    @property
    def metadata(self) -> SVDMetadata:
        """Metadata of the SVD, re-fetched while the fit job is not done."""
        self._refresh_metadata()
        return self._metadata

    def _refresh_metadata(self):
        """Re-fetch the metadata from the API unless it already reports done."""
        if not self._metadata.is_done():
            self._metadata = api.svd_get(session=self.session, svd_id=self._metadata.id)

    def get_model(self) -> EmbeddingModel:
        """
        Get the base embeddings model used to fit this SVD.

        Returns
        -------
        EmbeddingModel
            The embeddings model identified by the SVD's ``model_id``.
        """
        # Use the embeddings model id recorded in the metadata; the SVD's own
        # id (``self._metadata.id``) does not identify an embeddings model.
        return EmbeddingModel.create(
            session=self.session, model_id=self._metadata.model_id
        )

    @property
    def model(self) -> EmbeddingModel:
        """Base embeddings model used for the SVD."""
        return self.get_model()

    def delete(self) -> bool:
        """
        Delete this SVD model.

        Returns
        -------
        bool
            True if the deletion request succeeded.
        """
        return api.svd_delete(self.session, self.id)

    def get(self, verbose: bool = False):
        """Retrieve this SVD model itself."""
        return self

    def get_inputs(self) -> list[bytes]:
        """
        Get sequences used for svd job.

        Returns
        -------
        list[bytes]
            List of sequences
        """
        return api.svd_get_sequences(session=self.session, svd_id=self.id)

    def embed(
        self, sequences: list[bytes] | list[str], **kwargs
    ) -> "SVDEmbeddingsResultFuture":
        """
        Use this SVD model to get reduced embeddings from input sequences.

        Parameters
        ----------
        sequences : list of bytes or str
            List of protein sequences.
        kwargs :
            Extra keyword arguments forwarded to ``api.svd_embed_post``.

        Returns
        -------
        SVDEmbeddingsResultFuture
            Future result containing the reduced embeddings.
        """
        return SVDEmbeddingsResultFuture.create(
            session=self.session,
            job=api.svd_embed_post(
                session=self.session, svd_id=self.id, sequences=sequences, **kwargs
            ),
            sequences=sequences,
        )

    def fit_umap(
        self,
        sequences: list[bytes] | list[str] | None = None,
        assay: AssayDataset | None = None,
        n_components: int = 2,
        **kwargs,
    ) -> "UMAPModel":
        """
        Fit an UMAP on the embedding results of this model.

        This function will create an UMAPModel based on the embeddings from
        this model as well as the hyperparameters specified in the args.
        Exactly one of ``sequences`` or ``assay`` must be provided.

        Parameters
        ----------
        sequences : list of bytes or str, or None, optional
            Sequences to UMAP. Mutually exclusive with ``assay``.
        assay : AssayDataset or None, optional
            Assay whose sequences to UMAP. Mutually exclusive with ``sequences``.
        n_components : int
            Number of components in UMAP. Will determine output shapes.
        kwargs :
            Extra keyword arguments forwarded to the UMAP API's ``fit_umap``.

        Returns
        -------
        UMAPModel
            UMAP model fitted on the reduced embeddings from provided sequences or assay.

        Raises
        ------
        InvalidParameterError
            If both or neither of ``sequences`` and ``assay`` are provided.
        """
        # local import for cyclic dep
        from openprotein.umap import UMAPAPI

        umap_api = getattr(self.session, "umap", None)
        assert isinstance(umap_api, UMAPAPI)

        # Ensure either or
        if (assay is None and sequences is None) or (
            assay is not None and sequences is not None
        ):
            raise InvalidParameterError(
                "Expected either assay or sequences to fit UMAP on!"
            )
        return umap_api.fit_umap(
            model=self.id,
            feature_type=FeatureType.SVD,
            sequences=sequences,
            assay_id=assay.id if assay is not None else None,
            n_components=n_components,
            **kwargs,
        )

    def fit_gp(
        self,
        assay: AssayMetadata | AssayDataset | str,
        properties: list[str],
        name: str | None = None,
        description: str | None = None,
        **kwargs,
    ) -> "PredictorModel":
        """
        Fit a GP on assay using this embedding model and hyperparameters.

        Parameters
        ----------
        assay : AssayMetadata or AssayDataset or str
            Assay to fit GP on. Or its assay_id.
        properties : list of str
            Properties in the assay to fit the gp on. Currently exactly one
            property is supported.
        name : str or None, optional
            Optional name for the predictor.
        description : str or None, optional
            Optional description for the predictor.
        kwargs :
            Extra keyword arguments forwarded to the predictor API's ``fit_gp``.

        Returns
        -------
        PredictorModel
            Property predictor model trained using the reduced embeddings with provided assay and properties.

        Raises
        ------
        InvalidParameterError
            If the assay's sequence length conflicts with the SVD's sequence
            length constraint, if ``properties`` is empty, not a subset of
            the assay's measurements, or has more than one entry.
        """
        # local import to resolve cyclic
        from openprotein.predictor import PredictorAPI

        data_api = getattr(self.session, "data", None)
        assert isinstance(data_api, DataAPI)

        predictor_api = getattr(self.session, "predictor", None)
        assert isinstance(predictor_api, PredictorAPI)

        # get assay if str
        assay = data_api.get(assay_id=assay) if isinstance(assay, str) else assay
        if (
            self.sequence_length is not None
            and assay.sequence_length != self.sequence_length
        ):
            raise InvalidParameterError(
                f"Expected dataset to be of sequence length {self.sequence_length} due to svd fitted constraints"
            )
        if len(properties) == 0:
            raise InvalidParameterError("Expected (at-least) 1 property to train")
        if not set(properties) <= set(assay.measurement_names):
            raise InvalidParameterError(
                f"Expected all provided properties to be a subset of assay's measurements: {assay.measurement_names}"
            )
        # TODO - support multitask
        if len(properties) > 1:
            raise InvalidParameterError(
                "Training a multitask GP is not yet supported (i.e. number of properties should only be 1 for now)"
            )
        return predictor_api.fit_gp(
            assay=assay,
            properties=properties,
            feature_type=FeatureType.SVD,
            model=self,
            name=name,
            description=description,
            **kwargs,
        )
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
class SVDEmbeddingsResultFuture(EmbeddingsResultFuture, Future):
    """
    Future wrapping an SVD embeddings job.

    Resolving it yields a list of numpy arrays holding the SVD-reduced
    embeddings (presumably one per submitted sequence; ordering is decided
    by the parent ``EmbeddingsResultFuture``).
    """

    job: SVDEmbeddingsJob

    def wait(
        self,
        interval: int = config.POLLING_INTERVAL,
        timeout: int | None = None,
        verbose: bool = False,
    ) -> list[np.ndarray]:
        """
        Block on the SVD embeddings job, then return its embeddings.

        All arguments are passed positionally to the parent ``wait``.
        """
        result = super().wait(interval, timeout, verbose)
        return result

    def get(self, verbose=False) -> list[np.ndarray]:
        """Return every SVD-reduced embedding produced by the job."""
        embeddings = super().get(verbose)
        return embeddings

    def get_item(self, sequence: bytes) -> np.ndarray:
        """
        Fetch and decode the SVD embedding of a single sequence.

        Parameters
        ----------
        sequence: bytes
            Sequence whose SVD embedding should be retrieved.

        Returns
        -------
        np.ndarray
            The decoded SVD embedding.
        """
        raw = api.embed_get_sequence_result(
            session=self.session, job_id=self.job.job_id, sequence=sequence
        )
        return api.embed_decode(data=raw)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Schemas for OpenProtein SVD system."""
|
|
2
|
+
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, ConfigDict
|
|
7
|
+
|
|
8
|
+
from openprotein.jobs import BatchJob, Job, JobStatus, JobType
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SVDMetadata(BaseModel):
    """Metadata record describing an SVD and the status of its fit job."""

    # Silence pydantic's warning that the ``model_id`` field collides with the
    # protected ``model_`` namespace.
    model_config = ConfigDict(protected_namespaces=())

    id: str
    status: JobStatus
    created_date: datetime | None = None
    model_id: str
    n_components: int
    reduction: str | None = None
    sequence_length: int | None = None

    def is_done(self):
        """Report whether the SVD job is done, via ``JobStatus.done()``."""
        job_status = self.status
        return job_status.done()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class SVDFitJob(Job):
    """Job record for fitting an SVD; ``job_type`` is pinned to ``JobType.svd_fit``."""

    job_type: Literal[JobType.svd_fit]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class SVDEmbeddingsJob(Job, BatchJob):
    """Batch job record for SVD embedding; ``job_type`` is pinned to ``JobType.svd_embed``."""

    job_type: Literal[JobType.svd_embed]
|
openprotein/svd/svd.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
"""SVD API providing the interface for creating and using SVD models."""
|
|
2
|
+
|
|
3
|
+
from openprotein.base import APISession
|
|
4
|
+
from openprotein.common import ReductionType
|
|
5
|
+
from openprotein.data import AssayDataset, AssayMetadata
|
|
6
|
+
from openprotein.embeddings import EmbeddingModel, EmbeddingsAPI
|
|
7
|
+
|
|
8
|
+
from . import api
|
|
9
|
+
from .models import SVDModel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SVDAPI:
    """SVD API providing the interface for creating and using SVD models."""

    def __init__(
        self,
        session: APISession,
    ):
        """
        Parameters
        ----------
        session : APISession
            Session object for API communication.
        """
        self.session = session

    def fit_svd(
        self,
        model_id: str,
        sequences: list[bytes] | list[str] | None = None,
        assay: AssayMetadata | AssayDataset | str | None = None,
        n_components: int = 1024,
        reduction: ReductionType | None = None,
        **kwargs,
    ) -> SVDModel:
        """
        Fit an SVD on the sequences with the specified model_id and hyperparameters (n_components).

        Exactly one of ``sequences`` or ``assay`` must be provided.

        Parameters
        ----------
        model_id : str
            ID of embeddings model to use.
        sequences : list of bytes or str, or None, optional
            Sequences to fit the SVD with. Mutually exclusive with ``assay``.
        assay : AssayMetadata or AssayDataset or str or None, optional
            Assay containing sequences to fit the SVD with, or its
            assay_id. Mutually exclusive with ``sequences``.
        n_components : int, optional
            The number of components for the SVD. Defaults to 1024.
        reduction : ReductionType or None, optional
            Type of embedding reduction to use for computing features.
            E.g. "MEAN" or "SUM". Useful when dealing with variable length
            sequence. Defaults to None.
        kwargs :
            Additional keyword arguments to be passed to foundational models, e.g. prompt_id for PoET models.

        Returns
        -------
        SVDModel
            The SVD model being fit.

        Raises
        ------
        InvalidParameterError
            Raised by the fit request when both or neither of ``sequences``
            and ``assay`` are provided.
        """
        embeddings_api = getattr(self.session, "embedding", None)
        assert isinstance(embeddings_api, EmbeddingsAPI)
        model = embeddings_api.get_model(model_id)
        assert isinstance(model, EmbeddingModel), "Expected EmbeddingModel"
        # normalize the assay argument to its id (None stays None)
        if isinstance(assay, AssayMetadata):
            assay_id = assay.assay_id
        elif isinstance(assay, AssayDataset):
            assay_id = assay.id
        else:
            assay_id = assay
        return SVDModel(
            session=self.session,
            job=api.svd_fit_post(
                session=self.session,
                model_id=model.id,
                sequences=sequences,
                assay_id=assay_id,
                n_components=n_components,
                reduction=reduction,
                **kwargs,
            ),
        )

    def get_svd(self, svd_id: str) -> SVDModel:
        """
        Get SVD job results. Including SVD dimension and sequence lengths.

        Requires a successful SVD job from fit_svd

        Parameters
        ----------
        svd_id : str
            The ID of the SVD job.

        Returns
        -------
        SVDModel
            The model with the SVD fit.
        """
        metadata = api.svd_get(self.session, svd_id)
        return SVDModel(
            session=self.session,
            metadata=metadata,
        )

    def __delete_svd(self, svd_id: str) -> bool:
        """
        Delete SVD model.

        Parameters
        ----------
        svd_id : str
            The ID of the SVD job.

        Returns
        -------
        bool
            Whether or not the SVD was successfully deleted.
        """
        return api.svd_delete(self.session, svd_id)

    def list_svd(self) -> list[SVDModel]:
        """
        List SVD models made by user.

        Returns
        -------
        list of SVDModel
            List of SVDs that the user has access to.
        """
        return [
            SVDModel(
                session=self.session,
                metadata=metadata,
            )
            for metadata in api.svd_list_get(self.session)
        ]
|