biolmai 0.1.4__tar.gz → 0.1.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of biolmai has been flagged in its registry.
- {biolmai-0.1.4 → biolmai-0.1.7}/PKG-INFO +1 -1
- biolmai-0.1.7/biolmai/__init__.py +7 -0
- biolmai-0.1.7/biolmai/api.py +310 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/biolmai/asynch.py +90 -53
- {biolmai-0.1.4 → biolmai-0.1.7}/biolmai/auth.py +75 -29
- biolmai-0.1.7/biolmai/biolmai.py +5 -0
- biolmai-0.1.7/biolmai/cli.py +75 -0
- biolmai-0.1.7/biolmai/cls.py +97 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/biolmai/const.py +13 -11
- biolmai-0.1.7/biolmai/payloads.py +33 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/biolmai/validate.py +55 -28
- {biolmai-0.1.4 → biolmai-0.1.7}/biolmai.egg-info/PKG-INFO +1 -1
- biolmai-0.1.7/biolmai.egg-info/SOURCES.txt +64 -0
- biolmai-0.1.7/docs/_static/api_reference_icon.png +0 -0
- biolmai-0.1.7/docs/_static/chat_agents_icon.png +0 -0
- biolmai-0.1.7/docs/_static/jupyter_notebooks_icon.png +0 -0
- biolmai-0.1.7/docs/_static/model_docs_icon.png +0 -0
- biolmai-0.1.7/docs/_static/python_sdk_icon.png +0 -0
- biolmai-0.1.7/docs/_static/tutorials_icon.png +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/docs/conf.py +32 -44
- biolmai-0.1.7/docs/index.rst +107 -0
- biolmai-0.1.7/docs/model-docs/DNABERT.rst +640 -0
- biolmai-0.1.7/docs/model-docs/ESM-1v.rst +362 -0
- biolmai-0.1.7/docs/model-docs/ESM2_Embeddings.rst +242 -0
- biolmai-0.1.4/docs/model-docs/esm2_fold.rst → biolmai-0.1.7/docs/model-docs/ESMFold.rst +62 -63
- biolmai-0.1.7/docs/model-docs/ESM_InverseFold.rst +278 -0
- biolmai-0.1.7/docs/model-docs/ProtGPT2.rst +609 -0
- biolmai-0.1.7/docs/model-docs/ProteInfer_EC.rst +249 -0
- biolmai-0.1.7/docs/model-docs/ProteInfer_GO.rst +329 -0
- biolmai-0.1.7/docs/model-docs/index.rst +13 -0
- biolmai-0.1.7/docs/model-docs/progen2/ProGen2_BFD90.rst +251 -0
- biolmai-0.1.7/docs/model-docs/progen2/ProGen2_Medium.rst +248 -0
- biolmai-0.1.7/docs/model-docs/progen2/ProGen2_OAS.rst +246 -0
- biolmai-0.1.7/docs/model-docs/progen2/index.rst +10 -0
- biolmai-0.1.7/docs/python-client/get_started/authorization.rst +9 -0
- {biolmai-0.1.4/docs/python-client → biolmai-0.1.7/docs/python-client/get_started}/quickstart.rst +7 -0
- biolmai-0.1.7/docs/python-client/index.rst +18 -0
- biolmai-0.1.7/docs/python-client/usage.rst +7 -0
- biolmai-0.1.7/docs/tutorials_use_cases/notebooks.rst +9 -0
- biolmai-0.1.7/pyproject.toml +44 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/setup.cfg +7 -2
- biolmai-0.1.7/setup.py +53 -0
- biolmai-0.1.7/tests/test_biolmai.py +263 -0
- biolmai-0.1.4/biolmai/__init__.py +0 -15
- biolmai-0.1.4/biolmai/api.py +0 -394
- biolmai-0.1.4/biolmai/biolmai.py +0 -153
- biolmai-0.1.4/biolmai/cli.py +0 -67
- biolmai-0.1.4/biolmai/cls.py +0 -1
- biolmai-0.1.4/biolmai/payloads.py +0 -8
- biolmai-0.1.4/biolmai.egg-info/SOURCES.txt +0 -51
- biolmai-0.1.4/docs/index.rst +0 -74
- biolmai-0.1.4/docs/model-docs/admonitions.rst +0 -39
- biolmai-0.1.4/docs/model-docs/esm2_embeddings.rst +0 -10
- biolmai-0.1.4/docs/python-client/authors.rst +0 -1
- biolmai-0.1.4/docs/python-client/contributing.rst +0 -1
- biolmai-0.1.4/docs/python-client/history.rst +0 -1
- biolmai-0.1.4/docs/python-client/readme.rst +0 -1
- biolmai-0.1.4/docs/python-client/usage.rst +0 -8
- biolmai-0.1.4/docs/tutorials_use_cases/bulk_protein_folding.rst +0 -3
- biolmai-0.1.4/docs/tutorials_use_cases/dna_tutorials.rst +0 -8
- biolmai-0.1.4/docs/tutorials_use_cases/protein_tutorials.rst +0 -15
- biolmai-0.1.4/setup.py +0 -51
- biolmai-0.1.4/tests/test_biolmai.py +0 -226
- {biolmai-0.1.4 → biolmai-0.1.7}/AUTHORS.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/CONTRIBUTING.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/HISTORY.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/LICENSE +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/MANIFEST.in +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/README.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/biolmai/ltc.py +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/biolmai.egg-info/dependency_links.txt +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/biolmai.egg-info/entry_points.txt +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/biolmai.egg-info/not-zip-safe +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/biolmai.egg-info/requires.txt +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/biolmai.egg-info/top_level.txt +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/docs/Makefile +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/docs/_static/biolm_docs_logo_dark.png +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/docs/_static/biolm_docs_logo_light.png +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/docs/biolmai.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/docs/make.bat +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/docs/model-docs/img/book_icon.png +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/docs/model-docs/img/esmfold_perf.png +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/docs/modules.rst +0 -0
- {biolmai-0.1.4/docs/python-client → biolmai-0.1.7/docs/python-client/get_started}/installation.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.7}/tests/__init__.py +0 -0
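
The headline change in 0.1.7 is the rewritten biolmai/api.py, which introduces a class-based client: an APIEndpoint base class whose predict/transform/generate methods validate sequences locally, batch them, and fan the batches out to the REST API concurrently. A minimal usage sketch follows; the concrete model class name is an assumption (taken from the ESMFoldSingleChain name referenced in api.py's comments — the real exports live in biolmai/cls.py):

    # Hypothetical usage of the 0.1.7 class-based client. The class name is
    # assumed from comments in biolmai/api.py; check biolmai/cls.py for the
    # classes this release actually ships.
    import biolmai.cls

    model = biolmai.cls.ESMFoldSingleChain()  # auth headers resolved in __init__
    results = model.predict(["MSILVTRPSP", "MKTAYIAKQR"])  # str or list of str
    # One result per input row: the API response, or a local-validation error
    # for sequences that failed the regex checks before any request was made.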
biolmai-0.1.7/biolmai/api.py (new file)

@@ -0,0 +1,310 @@
+"""References to API endpoints."""
+import datetime
+import inspect
+import time
+from functools import lru_cache
+
+import numpy as np
+import pandas as pd
+import requests
+from requests.adapters import HTTPAdapter
+from requests.packages.urllib3.util.retry import Retry
+
+import biolmai
+import biolmai.auth
+from biolmai.asynch import async_api_call_wrapper
+from biolmai.biolmai import log
+from biolmai.const import MULTIPROCESS_THREADS
+from biolmai.payloads import INST_DAT_TXT, predict_resp_many_in_one_to_many_singles
+
+
+@lru_cache(maxsize=64)
+def validate_endpoint_action(allowed_classes, method_name, api_class_name):
+    action_method_name = method_name.split(".")[-1]
+    if action_method_name not in allowed_classes:
+        err = "Only {} supported on {}"
+        err = err.format(list(allowed_classes), api_class_name)
+        raise AssertionError(err)
+
+
+def text_validator(text, c):
+    """Validate some text against a class-based validator, returning a string
+    if invalid, or None otherwise."""
+    try:
+        c(text)
+    except Exception as e:
+        return str(e)
+
+
+def validate(f):
+    def wrapper(*args, **kwargs):
+        # Get class instance at runtime, so you can access not just
+        # APIEndpoints, but any *parent* classes of that,
+        # like ESMFoldSinglechain.
+        class_obj_self = args[0]
+        try:
+            is_method = inspect.getfullargspec(f)[0][0] == "self"
+        except Exception:
+            is_method = False
+
+        # Is the function we decorated a class method?
+        if is_method:
+            name = f"{f.__module__}.{class_obj_self.__class__.__name__}.{f.__name__}"
+        else:
+            name = f"{f.__module__}.{f.__name__}"
+
+        if is_method:
+            # Splits name, e.g. 'biolmai.api.ESMFoldSingleChain.predict'
+            action_method_name = name.split(".")[-1]
+            validate_endpoint_action(
+                class_obj_self.action_class_strings,
+                action_method_name,
+                class_obj_self.__class__.__name__,
+            )
+
+        input_data = args[1]
+        # Validate each row's text/input based on class attribute `seq_classes`
+        for c in class_obj_self.seq_classes:
+            # Validate input data against regex
+            if class_obj_self.multiprocess_threads:
+                validation = input_data.text.apply(text_validator, args=(c,))
+            else:
+                validation = input_data.text.apply(text_validator, args=(c,))
+            if "validation" not in input_data.columns:
+                input_data["validation"] = validation
+            else:
+                input_data["validation"] = input_data["validation"].str.cat(
+                    validation, sep="\n", na_rep=""
+                )
+
+        # Mark your batches, excluding invalid rows
+        valid_dat = input_data.loc[input_data.validation.isnull(), :].copy()
+        N = class_obj_self.batch_size  # N rows will go per API request
+        # JOIN back, which is by index
+        if valid_dat.shape[0] != input_data.shape[0]:
+            valid_dat["batch"] = np.arange(valid_dat.shape[0]) // N
+            input_data = input_data.merge(
+                valid_dat.batch, left_index=True, right_index=True, how="left"
+            )
+        else:
+            input_data["batch"] = np.arange(input_data.shape[0]) // N
+
+        res = f(class_obj_self, input_data, **kwargs)
+        return res
+
+    return wrapper
+
+
+def convert_input(f):
+    def wrapper(*args, **kwargs):
+        # Get the user-input data argument to the decorated function
+        # class_obj_self = args[0]
+        input_data = args[1]
+        # Make sure we have expected input types
+        acceptable_inputs = (str, list, tuple, np.ndarray, pd.DataFrame)
+        if not isinstance(input_data, acceptable_inputs):
+            err = "Input must be one or many DNA or protein strings"
+            raise ValueError(err)
+        # Convert single-sequence input to list
+        if isinstance(input_data, str):
+            input_data = [input_data]
+        # Make sure we don't have a matrix
+        if isinstance(input_data, np.ndarray) and len(input_data.shape) > 1:
+            err = "Detected Numpy matrix - input a single vector or array"
+            raise AssertionError(err)
+        # Make sure we don't have a >=2D DF
+        if isinstance(input_data, pd.DataFrame) and len(input_data.shape) > 1:
+            err = "Detected Pandas DataFrame - input a single vector or Series"
+            raise AssertionError(err)
+        input_data = pd.DataFrame(input_data, columns=["text"])
+        return f(args[0], input_data, **kwargs)
+
+    return wrapper
+
+
+class APIEndpoint:
+    batch_size = 3  # Overwrite in parent classes as needed
+
+    def __init__(self, multiprocess_threads=None):
+        # Check for instance-specific threads, otherwise read from env var
+        if multiprocess_threads is not None:
+            self.multiprocess_threads = multiprocess_threads
+        else:
+            self.multiprocess_threads = MULTIPROCESS_THREADS  # Could be False
+        # Get correct auth-like headers
+        self.auth_headers = biolmai.auth.get_user_auth_header()
+        self.action_class_strings = tuple(
+            [c.__name__.replace("Action", "").lower() for c in self.action_classes]
+        )
+
+    def post_batches(self, dat, slug, action, payload_maker, resp_key):
+        keep_batches = dat.loc[~dat.batch.isnull(), ["text", "batch"]]
+        if keep_batches.shape[0] == 0:
+            pass  # Do nothing - we made nice JSON errors to return in the DF
+            # err = "No inputs found following local validation"
+            # raise AssertionError(err)
+        if keep_batches.shape[0] > 0:
+            api_resps = async_api_call_wrapper(
+                keep_batches, slug, action, payload_maker, resp_key
+            )
+            if isinstance(api_resps, pd.DataFrame):
+                batch_res = api_resps.explode("api_resp")  # Should be lists of results
+                len_res = batch_res.shape[0]
+            else:
+                batch_res = pd.DataFrame({"api_resp": api_resps})
+                len_res = batch_res.shape[0]
+            orig_request_rows = keep_batches.shape[0]
+            if len_res != orig_request_rows:
+                err = "Response rows ({}) mismatch with input rows ({})"
+                err = err.format(len_res, orig_request_rows)
+                raise AssertionError(err)
+
+            # Stack the results horizontally w/ original rows of batches
+            keep_batches["prev_idx"] = keep_batches.index
+            keep_batches.reset_index(drop=False, inplace=True)
+            batch_res.reset_index(drop=True, inplace=True)
+            keep_batches["api_resp"] = batch_res
+            keep_batches.set_index("prev_idx", inplace=True)
+            dat = dat.join(keep_batches.reindex(["api_resp"], axis=1))
+        else:
+            dat["api_resp"] = None
+        return dat
+
+    def unpack_local_validations(self, dat):
+        dat.loc[dat.api_resp.isnull(), "api_resp"] = (
+            dat.loc[~dat.validation.isnull(), "validation"]
+            .apply(
+                predict_resp_many_in_one_to_many_singles, args=(None, None, True, None)
+            )
+            .explode()
+        )
+
+        return dat
+
+    @convert_input
+    @validate
+    def predict(self, dat):
+        dat = self.post_batches(dat, self.slug, "predict", INST_DAT_TXT, "predictions")
+        dat = self.unpack_local_validations(dat)
+        return dat.api_resp.replace(np.nan, None).tolist()
+
+    def infer(self, dat):
+        return self.predict(dat)
+
+    @convert_input
+    @validate
+    def transform(self, dat):
+        dat = self.post_batches(
+            dat, self.slug, "transform", INST_DAT_TXT, "predictions"
+        )
+        dat = self.unpack_local_validations(dat)
+        return dat.api_resp.replace(np.nan, None).tolist()
+
+    # @convert_input
+    # @validate
+    # def encode(self, dat):
+    #     # NOTE: we defined this for the specific case of ESM2
+    #     # TODO: this will be need again in v2 of API contract
+    #     dat = self.post_batches(dat, self.slug, "transform",
+    #                             INST_DAT_TXT, "embeddings")
+    #     dat = self.unpack_local_validations(dat)
+    #     return dat.api_resp.replace(np.nan, None).tolist()
+
+    @convert_input
+    @validate
+    def generate(self, dat):
+        dat = self.post_batches(dat, self.slug, "generate", INST_DAT_TXT, "generated")
+        dat = self.unpack_local_validations(dat)
+        return dat.api_resp.replace(np.nan, None).tolist()
+
+
+def retry_minutes(sess, URL, HEADERS, dat, timeout, mins):
+    """Retry for N minutes."""
+    HEADERS.update({"Content-Type": "application/json"})
+    attempts, max_attempts = 0, 5
+    try:
+        now = datetime.datetime.now()
+        try_until = now + datetime.timedelta(minutes=mins)
+        while datetime.datetime.now() < try_until and attempts < max_attempts:
+            response = None
+            try:
+                log.info(f"Trying {datetime.datetime.now()}")
+                response = sess.post(URL, headers=HEADERS, data=dat, timeout=timeout)
+                if response.status_code not in (400, 404):
+                    response.raise_for_status()
+                if "error" in response.json():
+                    raise ValueError(response.json().dumps())
+                else:
+                    break
+            except Exception as e:
+                log.warning(e)
+                if response:
+                    log.warning(response.text)
+                time.sleep(5)  # Wait 5 seconds between tries
+                attempts += 1
+        if response is None:
+            err = "Got Nonetype response"
+            raise ValueError(err)
+        elif "Server Error" in response.text:
+            err = "Got Server Error"
+            raise ValueError(err)
+    except Exception:
+        return response
+    return response
+
+
+def requests_retry_session(
+    retries=3,
+    backoff_factor=0.3,
+    status_forcelist=None,
+    session=None,
+):
+    if status_forcelist is None:
+        status_forcelist = list(range(400, 599))
+    session = session or requests.Session()
+    retry = Retry(
+        total=retries,
+        read=retries,
+        connect=retries,
+        backoff_factor=backoff_factor,
+        status_forcelist=status_forcelist,
+    )
+    adapter = HTTPAdapter(max_retries=retry)
+    session.mount("http://", adapter)
+    session.mount("https://", adapter)
+    return session
+
+
+class PredictAction:
+    def __str__(self):
+        return "PredictAction"
+
+
+class GenerateAction:
+    def __str__(self):
+        return "GenerateAction"
+
+
+class TransformAction:
+    def __str__(self):
+        return "TransformAction"
+
+
+# class EncodeAction:
+#     def __str__(self):
+#         return "EncodeAction"
+
+
+class ExplainAction:
+    def __str__(self):
+        return "ExplainAction"
+
+
+class SimilarityAction:
+    def __str__(self):
+        return "SimilarityAction"
+
+
+class FinetuneAction:
+    def __str__(self):
+        return "FinetuneAction"
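
Two details of the new api.py worth noting: requests_retry_session wires a standard urllib3 Retry into both HTTP adapters, and its default status_forcelist of range(400, 599) means plain client errors (401, 403, 422) are retried alongside server errors. A short caller sketch, with a placeholder URL:

    # Sketch of using the retry-enabled session from biolmai/api.py.
    # The endpoint URL here is a placeholder, not a documented route.
    from biolmai.api import requests_retry_session

    session = requests_retry_session(retries=5, backoff_factor=0.5)
    # Retries happen inside the adapter; exhausting them raises
    # requests.exceptions.RetryError rather than returning a response.
    resp = session.get("https://example.com/health/", timeout=10)
    resp.raise_for_status()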
{biolmai-0.1.4 → biolmai-0.1.7}/biolmai/asynch.py

@@ -1,23 +1,15 @@
+import asyncio
+from asyncio import create_task, gather, run
+from itertools import zip_longest
+from typing import Dict, List
+
 import aiohttp.resolver
+from aiohttp import ClientSession
 
+from biolmai.auth import get_user_auth_header
 from biolmai.const import BASE_API_URL, MULTIPROCESS_THREADS
 
 aiohttp.resolver.DefaultResolver = aiohttp.resolver.AsyncResolver
-from aiohttp import ClientSession, TCPConnector
-from typing import List
-import json
-import asyncio
-
-from asyncio import create_task, gather, run, sleep
-
-
-
-async def get_one(session: ClientSession, slug: str, action: str,
-                  payload: dict, response_key: str):
-    pass
-
-
-from aiohttp import ClientSession
 
 
 async def get_one(session: ClientSession, url: str) -> None:
@@ -30,25 +22,31 @@ async def get_one(session: ClientSession, url: str) -> None:
     return text_resp
 
 
-async def get_one_biolm(session: ClientSession,
-                        url: str,
-                        pload: dict,
-                        headers: dict,
-                        response_key: str = None) -> None:
+async def get_one_biolm(
+    session: ClientSession,
+    url: str,
+    pload: dict,
+    headers: dict,
+    response_key: str = None,
+) -> None:
     print("Requesting", url)
-    pload_batch = pload.pop('batch')
-    pload_batch_size = pload.pop('batch_size')
+    pload_batch = pload.pop("batch")
+    pload_batch_size = pload.pop("batch_size")
     t = aiohttp.ClientTimeout(
-        total=
-        # total timeout (time consists connection establishment for
+        total=1600,  # 27 mins
+        # total timeout (time consists connection establishment for
+        # a new connection or waiting for a free connection from a
+        # pool if pool connection limits are exceeded) default value
+        # is 5 minutes, set to `None` or `0` for unlimited timeout
         sock_connect=None,
-        # Maximal number of seconds for connecting to a peer for a
+        # Maximal number of seconds for connecting to a peer for a
+        # new connection, not given from a pool. See also connect.
        sock_read=None
         # Maximal number of seconds for reading a portion of data from a peer
     )
     async with session.post(url, headers=headers, json=pload, timeout=t) as resp:
         resp_json = await resp.json()
-        resp_json['batch'] = pload_batch
+        resp_json["batch"] = pload_batch
         status_code = resp.status
         expected_root_key = response_key
         to_ret = []
@@ -61,9 +59,7 @@ async def get_one_biolm(session: ClientSession,
         else:
             raise ValueError("Unexpected response in parser")
         for idx, item in enumerate(list_of_individual_seq_results):
-            d = {
-                'batch_id': pload_batch,
-                'batch_item': idx}
+            d = {"status_code": status_code, "batch_id": pload_batch, "batch_item": idx}
             if not status_code or status_code != 200:
                 d.update(item)  # Put all resp keys at root there
             else:
@@ -77,16 +73,15 @@ async def get_one_biolm(session: ClientSession,
     # await sleep(2)  # for demo purposes
     # text_resp = text.strip().split("\n", 1)[0]
     # print("Got response from", url, text_resp)
-    return j
 
 
 async def async_range(count):
     for i in range(count):
-        yield(i)
+        yield (i)
         await asyncio.sleep(0.0)
 
 
-async def get_all(urls: List[str], num_concurrent: int) -> List:
+async def get_all(urls: List[str], num_concurrent: int) -> list:
     url_iterator = iter(urls)
     keep_going = True
     results = []
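
Both get_all above and get_all_biolm below use the same bounded-concurrency scheme: pull up to num_concurrent items off an iterator, schedule them with create_task, await the wave with gather, repeat. A generic, self-contained version of that loop (the names here are illustrative, not part of the package):

    import asyncio
    from asyncio import create_task, gather

    async def bounded_gather(items, worker, num_concurrent):
        # Drain `items` in waves of `num_concurrent`, awaiting each wave
        # fully before scheduling the next - the scheme get_all_biolm uses.
        it = iter(items)
        results, keep_going = [], True
        while keep_going:
            tasks = []
            for _ in range(num_concurrent):
                try:
                    item = next(it)
                except StopIteration:
                    keep_going = False
                    break
                tasks.append(create_task(worker(item)))
            results.extend(await gather(*tasks))
        return results

    # e.g. asyncio.run(bounded_gather(payloads, post_one, num_concurrent=3))

Note this is wave-based rather than semaphore-based: one slow request holds up the start of the next wave of num_concurrent tasks.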
@@ -106,22 +101,26 @@ async def get_all(urls: List[str], num_concurrent: int) -> List:
     return results
 
 
-async def get_all_biolm(url: str,
-                        ploads: List[dict],
-                        headers: dict,
-                        num_concurrent: int,
-                        response_key: str = None) -> List:
+async def get_all_biolm(
+    url: str,
+    ploads: List[Dict],
+    headers: dict,
+    num_concurrent: int,
+    response_key: str = None,
+) -> list:
     ploads_iterator = iter(ploads)
     keep_going = True
     results = []
-    connector = aiohttp.TCPConnector(limit=100,
-                                     limit_per_host=50,
-                                     ttl_dns_cache=60)
+    connector = aiohttp.TCPConnector(limit=100, limit_per_host=50, ttl_dns_cache=60)
     ov_tout = aiohttp.ClientTimeout(
         total=None,
-        # total timeout (time consists connection establishment for
+        # total timeout (time consists connection establishment for
+        # a new connection or waiting for a free connection from a
+        # pool if pool connection limits are exceeded) default value
+        # is 5 minutes, set to `None` or `0` for unlimited timeout
        sock_connect=None,
-        # Maximal number of seconds for connecting to a peer for a
+        # Maximal number of seconds for connecting to a peer for a
+        # new connection, not given from a pool. See also connect.
         sock_read=None
         # Maximal number of seconds for reading a portion of data from a peer
     )
@@ -134,35 +133,31 @@ async def get_all_biolm(url: str,
             except StopIteration:
                 keep_going = False
                 break
-            new_task = create_task(
-                get_one_biolm(session, url, pload, headers, response_key))
+            new_task = create_task(
+                get_one_biolm(session, url, pload, headers, response_key)
+            )
             tasks.append(new_task)
         res = await gather(*tasks)
         results.extend(res)
     return results
 
 
-async def async_main(urls, concurrency) -> List:
+async def async_main(urls, concurrency) -> list:
     return await get_all(urls, concurrency)
 
 
-async def async_api_calls(model_name,
-                          action,
-                          headers,
-                          payloads,
-                          response_key=None):
+async def async_api_calls(model_name, action, headers, payloads, response_key=None):
     """Hit an arbitrary BioLM model inference API."""
     # Normally would POST multiple sequences at once for greater efficiency,
     # but for simplicity sake will do one at at time right now
-    url = f'{BASE_API_URL}/models/{model_name}/{action}/'
+    url = f"{BASE_API_URL}/models/{model_name}/{action}/"
 
     if not isinstance(payloads, (list, dict)):
         err = "API request payload must be a list or dict, got {}"
         raise AssertionError(err.format(type(payloads)))
 
     concurrency = int(MULTIPROCESS_THREADS)
-    return await get_all_biolm(url, payloads, headers, concurrency,
-                               response_key)
+    return await get_all_biolm(url, payloads, headers, concurrency, response_key)
 
     # payload = json.dumps(payload)
     # session = requests_retry_session()
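
For concreteness, the request async_api_calls ultimately issues per batch looks roughly like the following. The model slug and the instance layout are assumptions (INST_DAT_TXT in biolmai/payloads.py builds the real payload), while the "batch"/"batch_size" bookkeeping keys are confirmed by the diff, which shows get_one_biolm popping them before the POST:

    # Illustrative request shape only - slug, base URL, and payload fields
    # are assumptions; the real base URL lives in biolmai/const.py.
    BASE_API_URL = "https://biolm.ai/api/v1"  # placeholder value
    url = f"{BASE_API_URL}/models/esmfold-singlechain/predict/"
    pload = {
        "instances": [{"data": {"text": "MSILVTRPSP"}}],  # assumed layout
        "batch": 0,        # stripped client-side, used to reorder results
        "batch_size": 1,   # stripped client-side
    }
    # get_one_biolm pops "batch" and "batch_size", then POSTs the rest as JSON.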
@@ -183,3 +178,45 @@ async def async_api_calls(model_name,
     # headers = get_user_auth_header()  # Need to re-get these now
     # response = retry_minutes(session, url, headers, payload, tout, mins=10)
     # return response
+
+
+def async_api_call_wrapper(grouped_df, slug, action, payload_maker, response_key):
+    """Wrap API calls to assist with sequence validation as a pre-cursor to
+    each API call.
+    """
+    model_name = slug
+    # payload = payload_maker(grouped_df)
+    init_ploads = grouped_df.groupby("batch").apply(
+        payload_maker, include_batch_size=True
+    )
+    ploads = init_ploads.to_list()
+    init_ploads = init_ploads.to_frame(name="pload")
+    init_ploads["batch"] = init_ploads.index
+    init_ploads = init_ploads.reset_index(drop=True)
+    assert len(ploads) == init_ploads.shape[0]
+    for inst, b in zip_longest(ploads, init_ploads["batch"].to_list()):
+        if inst is None or b is None:
+            raise ValueError(
+                "ploads and init_ploads['batch'] are not of the same length"
+            )
+        inst["batch"] = b
+
+    headers = get_user_auth_header()  # Need to pull each time
+    # urls = [
+    #     "https://github.com",
+    #     "https://stackoverflow.com",
+    #     "https://python.org",
+    # ]
+    # concurrency = 3
+    api_resp = run(async_api_calls(model_name, action, headers, ploads, response_key))
+    api_resp = [item for sublist in api_resp for item in sublist]
+    api_resp = sorted(api_resp, key=lambda x: x["batch_id"])
+    # print(api_resp)
+    # api_resp = biolmai.api_call(model_name, action, headers, payload,
+    #                             response_key)
+    # resp_json = api_resp.json()
+    # batch_id = int(grouped_df.batch.iloc[0])
+    # batch_size = grouped_df.shape[0]
+    # response = predict_resp_many_in_one_to_many_singles(
+    #     resp_json, api_resp.status_code, batch_id, None, batch_size)
+    return api_resp
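
To make the batching contract in async_api_call_wrapper concrete, here is a small standalone sketch of its groupby step, with a toy payload_maker standing in for INST_DAT_TXT (whose exact output shape is assumed, not confirmed by this diff):

    import pandas as pd

    def toy_payload_maker(group, include_batch_size=False):
        # Stand-in for INST_DAT_TXT: one payload dict per batch group.
        # The "instances"/"data"/"text" layout is an assumption.
        pload = {"instances": [{"data": {"text": t}} for t in group["text"]]}
        if include_batch_size:
            pload["batch_size"] = group.shape[0]
        return pload

    df = pd.DataFrame({"text": ["MSILV", "MKTAYIAK", "GATTACA"],
                       "batch": [0, 0, 1]})
    ploads = df.groupby("batch").apply(toy_payload_maker,
                                       include_batch_size=True).to_list()
    # Two payload dicts: batch 0 carries two sequences, batch 1 carries one.
    # async_api_call_wrapper then stamps each dict with its "batch" id so
    # responses can be re-sorted by batch_id after the concurrent POSTs.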