openprotein-python 0.8.2__1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openprotein/__init__.py +164 -0
- openprotein/_version.py +48 -0
- openprotein/align/__init__.py +8 -0
- openprotein/align/align.py +395 -0
- openprotein/align/api.py +428 -0
- openprotein/align/future.py +55 -0
- openprotein/align/msa.py +129 -0
- openprotein/align/schemas.py +165 -0
- openprotein/base.py +181 -0
- openprotein/chains.py +88 -0
- openprotein/common/__init__.py +5 -0
- openprotein/common/features.py +7 -0
- openprotein/common/model_metadata.py +33 -0
- openprotein/common/reduction.py +8 -0
- openprotein/config.py +9 -0
- openprotein/csv.py +31 -0
- openprotein/data/__init__.py +9 -0
- openprotein/data/api.py +218 -0
- openprotein/data/assaydataset.py +178 -0
- openprotein/data/data.py +93 -0
- openprotein/data/schemas.py +27 -0
- openprotein/design/__init__.py +16 -0
- openprotein/design/api.py +259 -0
- openprotein/design/design.py +125 -0
- openprotein/design/future.py +146 -0
- openprotein/design/schemas.py +607 -0
- openprotein/embeddings/__init__.py +27 -0
- openprotein/embeddings/api.py +619 -0
- openprotein/embeddings/embeddings.py +151 -0
- openprotein/embeddings/esm.py +33 -0
- openprotein/embeddings/future.py +146 -0
- openprotein/embeddings/models.py +421 -0
- openprotein/embeddings/openprotein.py +21 -0
- openprotein/embeddings/poet.py +446 -0
- openprotein/embeddings/poet2.py +505 -0
- openprotein/embeddings/schemas.py +78 -0
- openprotein/errors.py +76 -0
- openprotein/fasta.py +92 -0
- openprotein/fold/__init__.py +21 -0
- openprotein/fold/alphafold2.py +131 -0
- openprotein/fold/api.py +287 -0
- openprotein/fold/boltz.py +691 -0
- openprotein/fold/esmfold.py +54 -0
- openprotein/fold/fold.py +107 -0
- openprotein/fold/future.py +509 -0
- openprotein/fold/models.py +139 -0
- openprotein/fold/schemas.py +39 -0
- openprotein/jobs/__init__.py +9 -0
- openprotein/jobs/api.py +71 -0
- openprotein/jobs/futures.py +746 -0
- openprotein/jobs/jobs.py +69 -0
- openprotein/jobs/schemas.py +135 -0
- openprotein/models/__init__.py +4 -0
- openprotein/models/base.py +63 -0
- openprotein/models/foundation/rfdiffusion.py +283 -0
- openprotein/models/models.py +33 -0
- openprotein/predictor/__init__.py +25 -0
- openprotein/predictor/api.py +384 -0
- openprotein/predictor/models.py +374 -0
- openprotein/predictor/prediction.py +79 -0
- openprotein/predictor/predictor.py +242 -0
- openprotein/predictor/schemas.py +113 -0
- openprotein/predictor/validate.py +40 -0
- openprotein/prompt/__init__.py +9 -0
- openprotein/prompt/api.py +505 -0
- openprotein/prompt/models.py +142 -0
- openprotein/prompt/prompt.py +130 -0
- openprotein/prompt/schemas.py +49 -0
- openprotein/protein.py +587 -0
- openprotein/svd/__init__.py +9 -0
- openprotein/svd/api.py +206 -0
- openprotein/svd/models.py +288 -0
- openprotein/svd/schemas.py +31 -0
- openprotein/svd/svd.py +134 -0
- openprotein/umap/__init__.py +9 -0
- openprotein/umap/api.py +259 -0
- openprotein/umap/models.py +211 -0
- openprotein/umap/schemas.py +35 -0
- openprotein/umap/umap.py +175 -0
- openprotein/utils/uuid.py +29 -0
- openprotein_python-0.8.2.dist-info/METADATA +176 -0
- openprotein_python-0.8.2.dist-info/RECORD +84 -0
- openprotein_python-0.8.2.dist-info/WHEEL +4 -0
- openprotein_python-0.8.2.dist-info/licenses/LICENSE.txt +30 -0
|
@@ -0,0 +1,619 @@
|
|
|
1
|
+
"""Embeddings REST API for making HTTP calls to our embeddings backend."""
|
|
2
|
+
|
|
3
|
+
import io
|
|
4
|
+
import random
|
|
5
|
+
import struct
|
|
6
|
+
from io import BytesIO
|
|
7
|
+
from typing import BinaryIO, Iterator
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from pydantic import TypeAdapter
|
|
11
|
+
|
|
12
|
+
from openprotein import csv
|
|
13
|
+
from openprotein.base import APISession
|
|
14
|
+
from openprotein.common import ModelMetadata
|
|
15
|
+
from openprotein.errors import InvalidParameterError
|
|
16
|
+
|
|
17
|
+
from .schemas import (
|
|
18
|
+
AttnJob,
|
|
19
|
+
EmbeddingsJob,
|
|
20
|
+
GenerateJob,
|
|
21
|
+
JobType,
|
|
22
|
+
LogitsJob,
|
|
23
|
+
ScoreIndelJob,
|
|
24
|
+
ScoreJob,
|
|
25
|
+
ScoreSingleSiteJob,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
# Base path shared by the embeddings endpoints in this module (model listing,
# job submission and result retrieval all build their URLs from it).
PATH_PREFIX = "v1/embeddings"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def list_models(session: APISession) -> list[str]:
    """
    Return the names of the embeddings models exposed by the backend.

    Args:
        session (APISession): API session used to issue the GET request

    Returns:
        list[str]: model names, taken as-is from the JSON response body.
    """
    models_endpoint = f"{PATH_PREFIX}/models"
    return session.get(models_endpoint).json()
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_model(session: APISession, model_id: str) -> ModelMetadata:
    """
    Fetch the metadata record for a single embeddings model.

    Args:
        session (APISession): API session used to issue the GET request
        model_id (str): identifier of the model to look up

    Returns:
        ModelMetadata: metadata constructed from the fields of the JSON
        response body.
    """
    metadata_json = session.get(f"{PATH_PREFIX}/models/{model_id}").json()
    return ModelMetadata(**metadata_json)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def get_request_sequences(
    session: APISession, job_id: str, job_type: JobType = JobType.embeddings_embed
) -> list[bytes]:
    """
    Fetch the input sequences that were submitted with a job.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    job_id : str
        ID of the job whose sequences should be fetched.
    job_type : JobType
        Job type whose ``value`` is appended to ``"v1"`` to form the base
        path. Defaults to ``JobType.embeddings_embed``.

    Returns
    -------
    sequences : list[bytes]
        The job's sequences, validated into a list of bytes.
    """
    # NOTE - allow to handle svd/embed and umap/embed directly too instead of redirect
    base_path = "v1" + job_type.value
    sequences_json = session.get(f"{base_path}/{job_id}/sequences").json()
    return TypeAdapter(list[bytes]).validate_python(sequences_json)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def request_get_sequence_result(
    session: APISession,
    job_id: str,
    sequence: str | bytes,
    job_type: JobType = JobType.embeddings_embed,
) -> bytes:
    """
    Fetch the raw result computed for one sequence of a job.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    job_id : str
        ID of the job holding the result.
    sequence : str | bytes
        Sequence whose result is wanted. Bytes are decoded to ``str`` before
        being placed in the URL path.
    job_type : JobType
        Job type whose ``value`` is appended to ``"v1"`` to form the base
        path. Defaults to ``JobType.embeddings_embed``.

    Returns
    -------
    result : bytes
        The response body, returned undecoded.
    """
    # NOTE - allow to handle svd/embed and umap/embed directly too instead of redirect
    base_path = "v1" + job_type.value
    sequence_text = sequence.decode() if isinstance(sequence, bytes) else sequence
    response = session.get(f"{base_path}/{job_id}/{sequence_text}")
    return response.content
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def result_decode(data: bytes) -> np.ndarray:
    """
    Deserialize an ``.npy``-encoded payload into a NumPy array.

    Args:
        data (bytes): raw bytes of the serialized array received over the API

    Returns:
        np.ndarray: the reconstructed array. Pickled (object) payloads are
        refused because loading uses ``allow_pickle=False``.
    """
    buffer = io.BytesIO(data)
    decoded = np.load(buffer, allow_pickle=False)
    return decoded
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def request_get_score_result(session: APISession, job_id: str) -> Iterator[list[str]]:
    """
    Stream the scores produced by a scoring job as parsed CSV rows.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    job_id : str
        ID of the scoring job to retrieve results from.

    Returns
    -------
    Iterator[list[str]]
        Rows produced by ``csv.parse_stream`` from the streamed response lines.
    """
    scores_endpoint = f"{PATH_PREFIX}/{job_id}/scores"
    streamed = session.get(scores_endpoint, stream=True)
    lines = streamed.iter_lines()
    return csv.parse_stream(lines)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def request_get_embeddings_stream(
    session: APISession, job_id: str
) -> Iterator[np.ndarray]:
    """
    Stream back the raw embeddings for a given embeddings job.

    Opens an HTTP GET to ``v1/embeddings/{job_id}/stream`` with
    ``stream=True``. It then reads a sequence of framed ``.npy`` payloads,
    each prefixed by an 8-byte big-endian length header. Each frame is decoded
    into a NumPy array and yielded as soon as it has been received.

    The HTTP response is always closed: when iteration completes, when an
    error is raised, or when the generator is closed early. This stops the
    connection from being held open.

    Parameters
    ----------
    session : APISession
        The API session to use for making requests.
    job_id : str
        The embeddings job identifier returned by ``request_post``.

    Yields
    ------
    numpy.ndarray
        An embedding array for each input sequence.

    Raises
    ------
    requests.HTTPError
        If the HTTP request returns a non-2xx status code.
    ValueError
        If the framed stream is malformed (e.g. incomplete header or payload).
    """
    endpoint = PATH_PREFIX + f"/{job_id}/stream"
    response = session.get(endpoint, stream=True)
    try:
        response.raise_for_status()
        # Ask the raw (urllib3) stream to undo any Content-Encoding so that the
        # frame parser sees the decoded byte stream.
        response.raw.decode_content = True
        buffered = io.BufferedReader(response.raw)  # type: ignore
        yield from parse_framed_npy_stream(buffered)
    finally:
        # stream=True keeps the connection checked out until the body is
        # consumed or closed; release it even on error / early abandonment.
        response.close()
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _read_exact(stream: BinaryIO, size: int) -> bytes:
    """Read ``size`` bytes from ``stream``, retrying short reads; returns fewer only at EOF."""
    chunks: list[bytes] = []
    remaining = size
    while remaining > 0:
        chunk = stream.read(remaining)
        if not chunk:
            # b"" means EOF. None (no data on a non-blocking raw stream) is
            # not retried either, so this never spins.
            break
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def parse_framed_npy_stream(stream: BinaryIO) -> Iterator[np.ndarray]:
    """
    Read a binary stream of length-prefixed NumPy .npy arrays.

    This function parses a stream composed of consecutive frames. Each frame
    starts with an 8-byte big-endian unsigned integer indicating the size of
    the subsequent .npy payload. It then reads exactly that many bytes and
    deserializes them into a NumPy array via np.load(..., allow_pickle=False).
    Frames are yielded one by one until the stream is exhausted.

    Short reads are retried until the requested byte count is reached or the
    stream hits EOF, so raw/unbuffered streams that return partial data work
    as well as buffered ones.

    Parameters
    ----------
    stream : BinaryIO
        A binary stream supporting read(n). It contains zero or more
        concatenated frames in the format:
        [8-byte big-endian length][.npy payload].

    Yields
    ------
    np.ndarray
        Each deserialized NumPy array from the stream.

    Raises
    ------
    ValueError
        If an 8-byte header cannot be read in full (unless at end of stream),
        or if a payload shorter than the declared length is encountered.
    """
    header = struct.Struct(">Q")
    while True:
        # Read the 8-byte length header
        try:
            length_bytes = _read_exact(stream, header.size)
        except ValueError:
            # underlying file got closed -> treat as EOF
            break
        if len(length_bytes) < header.size:
            if length_bytes:
                raise ValueError("Incomplete length header")
            break  # End of stream

        (npy_len,) = header.unpack(length_bytes)
        npy_bytes = _read_exact(stream, npy_len)
        if len(npy_bytes) < npy_len:
            raise ValueError("Incomplete npy payload")

        yield np.load(BytesIO(npy_bytes), allow_pickle=False)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def request_get_generate_result(
    session: APISession, job_id: str
) -> Iterator[list[str]]:
    """
    Stream the output of a generation job as parsed CSV rows.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    job_id : str
        ID of the generation job to retrieve results from.

    Returns
    -------
    Iterator[list[str]]
        Rows produced by ``csv.parse_stream`` from the streamed response lines.
    """
    generate_endpoint = f"{PATH_PREFIX}/{job_id}/generate"
    streamed = session.get(generate_endpoint, stream=True)
    lines = streamed.iter_lines()
    return csv.parse_stream(lines)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def request_post(
    session: APISession,
    model_id: str,
    sequences: list[bytes] | list[str],
    reduction: str | None = "MEAN",
    **kwargs,
) -> EmbeddingsJob:
    """
    Submit an embeddings job for ``model_id`` and return the created job.

    The returned job object can be used to retrieve the results later.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    model_id : str
        ID of the model to compute embeddings with.
    sequences : list[bytes] | list[str]
        Sequences to embed. Non-``str`` entries are decoded to ``str``.
    reduction : str | None
        Reduction to apply to the embeddings. Options are None, "MEAN", or
        "SUM"; the default is "MEAN". When None, no ``reduction`` field is sent.
    **kwargs
        Optional model-specific parameters, e.g. ``prompt_id`` for PoET.
        ``prompt_id``, ``query_id`` and ``decoder_type`` are sent only when
        truthy. ``use_query_structure_in_decoder`` is sent whenever it is
        present. Other keys are ignored.

    Returns
    -------
    job : EmbeddingsJob
    """
    endpoint = f"{PATH_PREFIX}/models/{model_id}/embed"
    body: dict = {
        "sequences": [
            seq if isinstance(seq, str) else seq.decode() for seq in sequences
        ],
    }
    if reduction is not None:
        body["reduction"] = reduction
    for key in ("prompt_id", "query_id"):
        if kwargs.get(key):
            body[key] = kwargs[key]
    if "use_query_structure_in_decoder" in kwargs:
        # forwarded even when explicitly False
        body["use_query_structure_in_decoder"] = kwargs["use_query_structure_in_decoder"]
    if kwargs.get("decoder_type"):
        body["decoder_type"] = kwargs["decoder_type"]
    return EmbeddingsJob.model_validate(session.post(endpoint, json=body).json())
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def request_logits_post(
    session: APISession,
    model_id: str,
    sequences: list[bytes] | list[str],
    **kwargs,
) -> LogitsJob:
    """
    Submit a logits job for ``model_id`` and return the created job.

    The returned job object can be used to retrieve the results later.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    model_id : str
        ID of the model to compute logits with.
    sequences : list[bytes] | list[str]
        Sequences to compute logits for. Non-``str`` entries are decoded.
    **kwargs
        Optional model-specific parameters, e.g. ``prompt_id`` for PoET.
        ``prompt_id``, ``query_id`` and ``decoder_type`` are sent only when
        truthy. ``use_query_structure_in_decoder`` is sent whenever it is
        present.

    Returns
    -------
    job : LogitsJob
    """
    endpoint = f"{PATH_PREFIX}/models/{model_id}/logits"
    decoded = [(seq if isinstance(seq, str) else seq.decode()) for seq in sequences]
    payload: dict = {"sequences": decoded}
    for option in ("prompt_id", "query_id"):
        value = kwargs.get(option)
        if value:
            payload[option] = kwargs[option]
    if "use_query_structure_in_decoder" in kwargs:
        payload["use_query_structure_in_decoder"] = kwargs["use_query_structure_in_decoder"]
    if kwargs.get("decoder_type"):
        payload["decoder_type"] = kwargs["decoder_type"]
    reply = session.post(endpoint, json=payload)
    return LogitsJob.model_validate(reply.json())
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def request_attn_post(
    session: APISession,
    model_id: str,
    sequences: list[bytes] | list[str],
    **kwargs,
) -> AttnJob:
    """
    Submit an attention-map job for ``model_id`` and return the created job.

    The returned job object can be used to retrieve the results later.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    model_id : str
        ID of the model to compute attention with.
    sequences : list[bytes] | list[str]
        Sequences to request attention for. Non-``str`` entries are decoded.
    **kwargs
        Optional model-specific parameters, e.g. ``prompt_id`` for PoET.
        ``prompt_id`` and ``query_id`` are sent only when truthy.
        ``use_query_structure_in_decoder`` is sent whenever it is present.
        Unlike the other job endpoints here, ``decoder_type`` is not forwarded.

    Returns
    -------
    job : AttnJob
    """
    endpoint = f"{PATH_PREFIX}/models/{model_id}/attn"
    request_body: dict = {
        "sequences": [s if isinstance(s, str) else s.decode() for s in sequences]
    }
    prompt_id = kwargs.get("prompt_id")
    if prompt_id:
        request_body["prompt_id"] = kwargs["prompt_id"]
    query_id = kwargs.get("query_id")
    if query_id:
        request_body["query_id"] = kwargs["query_id"]
    if "use_query_structure_in_decoder" in kwargs:
        request_body["use_query_structure_in_decoder"] = kwargs["use_query_structure_in_decoder"]
    return AttnJob.model_validate(session.post(endpoint, json=request_body).json())
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
def request_score_post(
    session: APISession,
    model_id: str,
    sequences: list[bytes] | list[str],
    **kwargs,
) -> ScoreJob:
    """
    Submit a sequence-scoring job for ``model_id`` and return the created job.

    The returned job object can be used to retrieve the results later.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    model_id : str
        ID of the model to score with.
    sequences : list[bytes] | list[str]
        Sequences to score. Non-``str`` entries are decoded.
    **kwargs
        Optional model-specific parameters, e.g. ``prompt_id`` for PoET.
        ``prompt_id``, ``query_id`` and ``decoder_type`` are sent only when
        truthy. ``use_query_structure_in_decoder`` is sent whenever it is
        present.

    Returns
    -------
    job : ScoreJob
    """
    endpoint = f"{PATH_PREFIX}/models/{model_id}/score"
    body: dict = {
        "sequences": [(s if isinstance(s, str) else s.decode()) for s in sequences],
    }
    for name in ("prompt_id", "query_id"):
        if kwargs.get(name):
            body[name] = kwargs[name]
    if "use_query_structure_in_decoder" in kwargs:
        body["use_query_structure_in_decoder"] = kwargs["use_query_structure_in_decoder"]
    if kwargs.get("decoder_type"):
        body["decoder_type"] = kwargs["decoder_type"]
    job_json = session.post(endpoint, json=body).json()
    return ScoreJob.model_validate(job_json)
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
def request_score_indel_post(
    session: APISession,
    model_id: str,
    base_sequence: bytes | str,
    insert: str | None = None,
    delete: list[int] | None = None,
    **kwargs,
) -> ScoreIndelJob:
    """
    Submit an indel-scoring job (``score/indel``) for ``model_id``.

    The returned job object can be used to retrieve the results later.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    model_id : str
        ID of the model to score with.
    base_sequence : bytes | str
        Sequence the insertions/deletions are applied to. Bytes are decoded.
    insert : str | None
        Insertion fragment at each site. Omitted from the request when None.
    delete : list[int] | None
        Range of fragment sizes to delete at each site. Omitted when None.
    **kwargs
        Optional model-specific parameters, e.g. ``prompt_id`` for PoET.
        ``prompt_id``, ``query_id`` and ``decoder_type`` are sent only when
        truthy. ``use_query_structure_in_decoder`` is sent whenever it is
        present.

    Returns
    -------
    job : ScoreIndelJob
    """
    endpoint = f"{PATH_PREFIX}/models/{model_id}/score/indel"
    if isinstance(base_sequence, bytes):
        base_text = base_sequence.decode()
    else:
        base_text = base_sequence
    body: dict = {"base_sequence": base_text}
    if insert is not None:
        body["insert"] = insert
    if delete is not None:
        body["delete"] = delete
    for key in ("prompt_id", "query_id"):
        if kwargs.get(key):
            body[key] = kwargs[key]
    if "use_query_structure_in_decoder" in kwargs:
        body["use_query_structure_in_decoder"] = kwargs["use_query_structure_in_decoder"]
    if kwargs.get("decoder_type"):
        body["decoder_type"] = kwargs["decoder_type"]
    return ScoreIndelJob.model_validate(session.post(endpoint, json=body).json())
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def request_score_single_site_post(
    session: APISession,
    model_id: str,
    base_sequence: bytes | str,
    **kwargs,
) -> ScoreSingleSiteJob:
    """
    Submit a single-site mutation scoring job for ``model_id``.

    The returned job object can be used to retrieve the results later.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    model_id : str
        ID of the model to score with.
    base_sequence : bytes | str
        Sequence whose single-site mutations are scored. Bytes are decoded.
    **kwargs
        Optional model-specific parameters, e.g. ``prompt_id`` for PoET.
        ``prompt_id``, ``query_id`` and ``decoder_type`` are sent only when
        truthy. ``use_query_structure_in_decoder`` is sent whenever it is
        present.

    Returns
    -------
    job : ScoreSingleSiteJob
    """
    endpoint = f"{PATH_PREFIX}/models/{model_id}/score_single_site"
    body: dict = {
        "base_sequence": (
            base_sequence if not isinstance(base_sequence, bytes) else base_sequence.decode()
        ),
    }
    for key in ("prompt_id", "query_id"):
        if kwargs.get(key):
            body[key] = kwargs[key]
    if "use_query_structure_in_decoder" in kwargs:
        body["use_query_structure_in_decoder"] = kwargs["use_query_structure_in_decoder"]
    if kwargs.get("decoder_type"):
        body["decoder_type"] = kwargs["decoder_type"]
    reply = session.post(endpoint, json=body)
    return ScoreSingleSiteJob.model_validate(reply.json())
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
def request_generate_post(
    session: APISession,
    model_id: str,
    num_samples: int = 100,
    temperature: float = 1.0,
    topk: float | None = None,
    topp: float | None = None,
    max_length: int = 1000,
    random_seed: int | None = None,
    **kwargs,
) -> GenerateJob:
    """
    POST a request for sequence generation for the given model ID.

    Returns a Job object referring to this request that can be used to
    retrieve results later.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    model_id : str
        model ID to request results from
    num_samples : int
        Number of sequences to generate (sent as ``n_sequences``).
    temperature : float
        Sampling temperature; must be between 0.1 and 2.
    topk : float | None
        If given, must be between 2 and 20.
    topp : float | None
        If given, must be between 0 and 1.
    max_length : int
        Maximum generated length (sent as ``maxlen``).
    random_seed : int | None
        Seed in ``[0, 2**32 - 1]``. When None, one is drawn with
        ``random.randrange(2**32)``, so a seed is always sent.
    **kwargs:
        Optional parameters for models, e.g. prompt_id for PoET.
        ``prompt_id`` and ``query_id`` are sent only when truthy.
        ``use_query_structure_in_decoder`` is sent whenever present.
        ``ensemble_weights`` (converted to a list) and ``ensemble_method``
        are sent when not None.

    Returns
    -------
    job : GenerateJob

    Raises
    ------
    InvalidParameterError
        If temperature, topk, topp or random_seed is out of range.
    AssertionError
        If ``model_id == "poet"`` and ``query_id``, ``ensemble_weights`` or
        ``ensemble_method`` is supplied.
    """
    endpoint = PATH_PREFIX + f"/models/{model_id}/generate"

    if not (0.1 <= temperature <= 2):
        raise InvalidParameterError("The 'temperature' must be between 0.1 and 2.")
    if topk is not None and not (2 <= topk <= 20):
        raise InvalidParameterError("The 'topk' must be between 2 and 20.")
    if topp is not None and not (0 <= topp <= 1):
        raise InvalidParameterError("The 'topp' must be between 0 and 1.")
    # The auto-generated seed below comes from randrange(2**32), i.e. the range
    # [0, 2**32 - 1]; accept the same range for user seeds (2**32 itself does
    # not fit in an unsigned 32-bit seed).
    if random_seed is not None and not (0 <= random_seed < 2**32):
        raise InvalidParameterError(
            "The 'random_seed' must be between 0 and 2^32 - 1."
        )

    if random_seed is None:
        random_seed = random.randrange(2**32)

    body: dict = {
        "n_sequences": num_samples,
        "temperature": temperature,
        "maxlen": max_length,
    }
    if topk is not None:
        body["topk"] = topk
    if topp is not None:
        body["topp"] = topp
    # random_seed is always set at this point.
    body["seed"] = random_seed
    if kwargs.get("prompt_id"):
        body["prompt_id"] = kwargs["prompt_id"]
    # The PoET checks below use explicit raises rather than `assert` so they
    # still run under `python -O`; AssertionError is kept so existing callers
    # see the same exception type.
    if kwargs.get("query_id"):
        if model_id == "poet":
            raise AssertionError(f"Model with id {model_id} does not support query")
        body["query_id"] = kwargs["query_id"]
    if "use_query_structure_in_decoder" in kwargs:
        body["use_query_structure_in_decoder"] = kwargs[
            "use_query_structure_in_decoder"
        ]
    if (ensemble_weights := kwargs.get("ensemble_weights")) is not None:
        if model_id == "poet":
            raise AssertionError(
                f"Model with id {model_id} does not support ensemble_weights parameter"
            )
        body["ensemble_weights"] = list(ensemble_weights)
    if (ensemble_method := kwargs.get("ensemble_method")) is not None:
        if model_id == "poet":
            raise AssertionError(
                f"Model with id {model_id} does not support ensemble_method parameter"
            )
        body["ensemble_method"] = ensemble_method
    response = session.post(endpoint, json=body)
    return GenerateJob.model_validate(response.json())
|