biolmai 0.1.5 (py2.py3-none-any.whl) → 0.1.7 (py2.py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of biolmai might be problematic.
- biolmai/__init__.py +3 -7
- biolmai/api.py +93 -89
- biolmai/asynch.py +65 -67
- biolmai/auth.py +47 -44
- biolmai/biolmai.py +1 -3
- biolmai/cli.py +30 -22
- biolmai/cls.py +33 -36
- biolmai/const.py +13 -11
- biolmai/payloads.py +9 -10
- biolmai/validate.py +55 -28
- {biolmai-0.1.5.dist-info → biolmai-0.1.7.dist-info}/METADATA +1 -1
- biolmai-0.1.7.dist-info/RECORD +18 -0
- {biolmai-0.1.5.dist-info → biolmai-0.1.7.dist-info}/WHEEL +1 -1
- biolmai-0.1.5.dist-info/RECORD +0 -18
- {biolmai-0.1.5.dist-info → biolmai-0.1.7.dist-info}/AUTHORS.rst +0 -0
- {biolmai-0.1.5.dist-info → biolmai-0.1.7.dist-info}/LICENSE +0 -0
- {biolmai-0.1.5.dist-info → biolmai-0.1.7.dist-info}/entry_points.txt +0 -0
- {biolmai-0.1.5.dist-info → biolmai-0.1.7.dist-info}/top_level.txt +0 -0
biolmai/__init__.py
CHANGED
@@ -1,11 +1,7 @@
 """Top-level package for BioLM AI."""
 __author__ = """Nikhil Haas"""
-__email__ =
-__version__ = '0.1.5'
+__email__ = "nikhil@biolm.ai"
+__version__ = '0.1.7'

-from biolmai.auth import get_api_token
-from biolmai.cls import ESMFoldSingleChain, ESMFoldMultiChain, ESM2Embeddings, ESM1v1, ESM1v2, ESM1v3, ESM1v4, ESM1v5

-__all__ = [
-
-]
+__all__ = []
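With the 0.1.7 `__init__`, the top-level re-exports (`get_api_token` and the model classes) are gone and `__all__` is empty, so callers import from the defining modules instead. A minimal usage sketch, assuming the class names in `biolmai/cls.py` are unchanged in 0.1.7 and that an API token is already configured:

    # Hypothetical usage after 0.1.7; imports come from biolmai.cls, not the package root
    from biolmai.cls import ESMFoldSingleChain

    model = ESMFoldSingleChain()
    # predict() accepts a str, list, tuple, or 1-D array/Series of sequences
    results = model.predict(["MDNELE", "MENDEL"])
    print(results)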
biolmai/api.py
CHANGED
@@ -1,33 +1,29 @@
 """References to API endpoints."""
 import datetime
+import inspect
 import time
+from functools import lru_cache

+import numpy as np
+import pandas as pd
 import requests
 from requests.adapters import HTTPAdapter
+from requests.packages.urllib3.util.retry import Retry

-import biolmai.auth
 import biolmai
-import
-import pandas as pd
-import numpy as np
+import biolmai.auth
 from biolmai.asynch import async_api_call_wrapper
-
 from biolmai.biolmai import log
 from biolmai.const import MULTIPROCESS_THREADS
-from functools import lru_cache
-
 from biolmai.payloads import INST_DAT_TXT, predict_resp_many_in_one_to_many_singles


 @lru_cache(maxsize=64)
 def validate_endpoint_action(allowed_classes, method_name, api_class_name):
-    action_method_name = method_name.split(
+    action_method_name = method_name.split(".")[-1]
     if action_method_name not in allowed_classes:
-        err =
-        err = err.format(
-            list(allowed_classes),
-            api_class_name
-        )
+        err = "Only {} supported on {}"
+        err = err.format(list(allowed_classes), api_class_name)
         raise AssertionError(err)

@@ -47,25 +43,23 @@ def validate(f):
         # like ESMFoldSinglechain.
         class_obj_self = args[0]
         try:
-            is_method = inspect.getfullargspec(f)[0][0] ==
-        except:
+            is_method = inspect.getfullargspec(f)[0][0] == "self"
+        except Exception:
             is_method = False

         # Is the function we decorated a class method?
         if is_method:
-            name =
-                class_obj_self.__class__.__name__,
-                f.__name__)
+            name = f"{f.__module__}.{class_obj_self.__class__.__name__}.{f.__name__}"
         else:
-            name =
+            name = f"{f.__module__}.{f.__name__}"

         if is_method:
             # Splits name, e.g. 'biolmai.api.ESMFoldSingleChain.predict'
-            action_method_name = name.split(
+            action_method_name = name.split(".")[-1]
             validate_endpoint_action(
                 class_obj_self.action_class_strings,
                 action_method_name,
-                class_obj_self.__class__.__name__
+                class_obj_self.__class__.__name__,
             )

         input_data = args[1]
@@ -73,35 +67,38 @@ def validate(f):
         for c in class_obj_self.seq_classes:
             # Validate input data against regex
             if class_obj_self.multiprocess_threads:
-                validation = input_data.text.apply(text_validator, args=(c,
+                validation = input_data.text.apply(text_validator, args=(c,))
             else:
-                validation = input_data.text.apply(text_validator, args=(c,
-            if
-                input_data[
+                validation = input_data.text.apply(text_validator, args=(c,))
+            if "validation" not in input_data.columns:
+                input_data["validation"] = validation
             else:
-                input_data[
-                    validation, sep=
+                input_data["validation"] = input_data["validation"].str.cat(
+                    validation, sep="\n", na_rep=""
+                )

         # Mark your batches, excluding invalid rows
         valid_dat = input_data.loc[input_data.validation.isnull(), :].copy()
         N = class_obj_self.batch_size  # N rows will go per API request
         # JOIN back, which is by index
         if valid_dat.shape[0] != input_data.shape[0]:
-            valid_dat[
+            valid_dat["batch"] = np.arange(valid_dat.shape[0]) // N
             input_data = input_data.merge(
-                valid_dat.batch, left_index=True, right_index=True, how=
+                valid_dat.batch, left_index=True, right_index=True, how="left"
+            )
         else:
-            input_data[
+            input_data["batch"] = np.arange(input_data.shape[0]) // N

         res = f(class_obj_self, input_data, **kwargs)
         return res
+
     return wrapper


 def convert_input(f):
     def wrapper(*args, **kwargs):
         # Get the user-input data argument to the decorated function
-        class_obj_self = args[0]
+        # class_obj_self = args[0]
         input_data = args[1]
         # Make sure we have expected input types
         acceptable_inputs = (str, list, tuple, np.ndarray, pd.DataFrame)
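The batching scheme in `validate` is unchanged by the reformat: valid rows get a `batch` label computed as `np.arange(n_rows) // N`, so every `N` consecutive valid rows share one API request. A standalone illustration of that grouping (not package code):

    import numpy as np
    import pandas as pd

    batch_size = 3  # matches the APIEndpoint.batch_size default in this diff
    df = pd.DataFrame({"text": ["MDNELE", "MENDEL", "MALWMR", "GIVEQC", "CCTSIC"]})

    # Integer-divide the row position by the batch size to label request groups
    df["batch"] = np.arange(df.shape[0]) // batch_size
    print(df.groupby("batch")["text"].apply(list).to_dict())
    # {0: ['MDNELE', 'MENDEL', 'MALWMR'], 1: ['GIVEQC', 'CCTSIC']}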
@@ -119,12 +116,13 @@ def convert_input(f):
         if isinstance(input_data, pd.DataFrame) and len(input_data.shape) > 1:
             err = "Detected Pandas DataFrame - input a single vector or Series"
             raise AssertionError(err)
-        input_data = pd.DataFrame(input_data, columns=[
+        input_data = pd.DataFrame(input_data, columns=["text"])
         return f(args[0], input_data, **kwargs)
+
     return wrapper


-class APIEndpoint
+class APIEndpoint:
     batch_size = 3  # Overwrite in parent classes as needed

     def __init__(self, multiprocess_threads=None):
@@ -135,29 +133,25 @@ class APIEndpoint(object):
         self.multiprocess_threads = MULTIPROCESS_THREADS  # Could be False
         # Get correct auth-like headers
         self.auth_headers = biolmai.auth.get_user_auth_header()
-        self.action_class_strings = tuple(
-            c.__name__.replace(
-
+        self.action_class_strings = tuple(
+            [c.__name__.replace("Action", "").lower() for c in self.action_classes]
+        )

     def post_batches(self, dat, slug, action, payload_maker, resp_key):
-        keep_batches = dat.loc[~dat.batch.isnull(), [
+        keep_batches = dat.loc[~dat.batch.isnull(), ["text", "batch"]]
         if keep_batches.shape[0] == 0:
             pass  # Do nothing - we made nice JSON errors to return in the DF
             # err = "No inputs found following local validation"
             # raise AssertionError(err)
         if keep_batches.shape[0] > 0:
             api_resps = async_api_call_wrapper(
-                keep_batches,
-                slug,
-                action,
-                payload_maker,
-                resp_key
+                keep_batches, slug, action, payload_maker, resp_key
             )
             if isinstance(api_resps, pd.DataFrame):
-                batch_res = api_resps.explode(
+                batch_res = api_resps.explode("api_resp")  # Should be lists of results
                 len_res = batch_res.shape[0]
             else:
-                batch_res = pd.DataFrame({
+                batch_res = pd.DataFrame({"api_resp": api_resps})
                 len_res = batch_res.shape[0]
             orig_request_rows = keep_batches.shape[0]
             if len_res != orig_request_rows:
@@ -166,29 +160,31 @@ class APIEndpoint(object):
                 raise AssertionError(err)

             # Stack the results horizontally w/ original rows of batches
-            keep_batches[
+            keep_batches["prev_idx"] = keep_batches.index
             keep_batches.reset_index(drop=False, inplace=True)
             batch_res.reset_index(drop=True, inplace=True)
-            keep_batches[
-            keep_batches.set_index(
-            dat = dat.join(keep_batches.reindex([
+            keep_batches["api_resp"] = batch_res
+            keep_batches.set_index("prev_idx", inplace=True)
+            dat = dat.join(keep_batches.reindex(["api_resp"], axis=1))
         else:
-            dat[
+            dat["api_resp"] = None
         return dat

     def unpack_local_validations(self, dat):
-        dat.loc[
-            dat.
-
-
-
+        dat.loc[dat.api_resp.isnull(), "api_resp"] = (
+            dat.loc[~dat.validation.isnull(), "validation"]
+            .apply(
+                predict_resp_many_in_one_to_many_singles, args=(None, None, True, None)
+            )
+            .explode()
+        )

         return dat

     @convert_input
     @validate
     def predict(self, dat):
-        dat = self.post_batches(dat, self.slug,
+        dat = self.post_batches(dat, self.slug, "predict", INST_DAT_TXT, "predictions")
         dat = self.unpack_local_validations(dat)
         return dat.api_resp.replace(np.nan, None).tolist()

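`convert_input` still funnels every accepted input into a one-column frame named `text`; the changes above only comment out the unused `class_obj_self` lookup and finish the `pd.DataFrame(..., columns=["text"])` call on one line. Roughly how that normalization behaves in isolation (the hunks do not show the bare-string branch, so the sketch wraps a lone string by hand):

    import numpy as np
    import pandas as pd

    for raw in (["MDNELE", "MENDEL"], ("MDNELE",), np.array(["MDNELE"]), "MDNELE"):
        seqs = [raw] if isinstance(raw, str) else raw  # a bare string becomes one row
        frame = pd.DataFrame(seqs, columns=["text"])
        print(type(raw).__name__, frame.shape)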
@@ -198,21 +194,33 @@ class APIEndpoint(object):
     @convert_input
     @validate
     def transform(self, dat):
-        dat = self.post_batches(
+        dat = self.post_batches(
+            dat, self.slug, "transform", INST_DAT_TXT, "predictions"
+        )
         dat = self.unpack_local_validations(dat)
         return dat.api_resp.replace(np.nan, None).tolist()

+    # @convert_input
+    # @validate
+    # def encode(self, dat):
+    #     # NOTE: we defined this for the specific case of ESM2
+    #     # TODO: this will be need again in v2 of API contract
+    #     dat = self.post_batches(dat, self.slug, "transform",
+    #                             INST_DAT_TXT, "embeddings")
+    #     dat = self.unpack_local_validations(dat)
+    #     return dat.api_resp.replace(np.nan, None).tolist()
+
     @convert_input
     @validate
     def generate(self, dat):
-        dat = self.post_batches(dat, self.slug,
+        dat = self.post_batches(dat, self.slug, "generate", INST_DAT_TXT, "generated")
         dat = self.unpack_local_validations(dat)
         return dat.api_resp.replace(np.nan, None).tolist()


 def retry_minutes(sess, URL, HEADERS, dat, timeout, mins):
     """Retry for N minutes."""
-    HEADERS.update({
+    HEADERS.update({"Content-Type": "application/json"})
     attempts, max_attempts = 0, 5
     try:
         now = datetime.datetime.now()
@@ -220,16 +228,11 @@ def retry_minutes(sess, URL, HEADERS, dat, timeout, mins):
         while datetime.datetime.now() < try_until and attempts < max_attempts:
             response = None
             try:
-                log.info(
-                response = sess.post(
-                    URL,
-                    headers=HEADERS,
-                    data=dat,
-                    timeout=timeout
-                )
+                log.info(f"Trying {datetime.datetime.now()}")
+                response = sess.post(URL, headers=HEADERS, data=dat, timeout=timeout)
                 if response.status_code not in (400, 404):
                     response.raise_for_status()
-                if
+                if "error" in response.json():
                     raise ValueError(response.json().dumps())
                 else:
                     break
@@ -242,10 +245,10 @@ def retry_minutes(sess, URL, HEADERS, dat, timeout, mins):
                 if response is None:
                     err = "Got Nonetype response"
                     raise ValueError(err)
-                elif
+                elif "Server Error" in response.text:
                     err = "Got Server Error"
                     raise ValueError(err)
-    except Exception
+    except Exception:
         return response
     return response

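`retry_minutes` keeps its control flow, a wall-clock deadline combined with an attempt cap, with the logging and `sess.post(...)` calls now written on single lines. The same loop shape in a generic, package-independent sketch:

    import datetime
    import time

    def retry_for_minutes(fn, mins=2, max_attempts=5):
        # Stop on whichever comes first: the deadline or the attempt cap
        try_until = datetime.datetime.now() + datetime.timedelta(minutes=mins)
        attempts, result = 0, None
        while datetime.datetime.now() < try_until and attempts < max_attempts:
            try:
                result = fn()
                break
            except Exception:
                attempts += 1
                time.sleep(1)  # fixed pause between attempts, for the sketch only
        return result

    print(retry_for_minutes(lambda: 42))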
@@ -253,54 +256,55 @@ def retry_minutes(sess, URL, HEADERS, dat, timeout, mins):
 def requests_retry_session(
     retries=3,
     backoff_factor=0.3,
-    status_forcelist=
+    status_forcelist=None,
     session=None,
 ):
+    if status_forcelist is None:
+        status_forcelist = list(range(400, 599))
     session = session or requests.Session()
     retry = Retry(
         total=retries,
         read=retries,
         connect=retries,
         backoff_factor=backoff_factor,
-        status_forcelist=status_forcelist
+        status_forcelist=status_forcelist,
     )
     adapter = HTTPAdapter(max_retries=retry)
-    session.mount(
-    session.mount(
+    session.mount("http://", adapter)
+    session.mount("https://", adapter)
     return session


-class PredictAction
-
+class PredictAction:
     def __str__(self):
-        return
+        return "PredictAction"


-class GenerateAction
-
+class GenerateAction:
     def __str__(self):
-        return
-
+        return "GenerateAction"

-class TransformAction(object):

+class TransformAction:
     def __str__(self):
-        return
+        return "TransformAction"


-class
+# class EncodeAction:
+#     def __str__(self):
+#         return "EncodeAction"

-    def __str__(self):
-        return 'ExplainAction'

+class ExplainAction:
+    def __str__(self):
+        return "ExplainAction"

-class SimilarityAction(object):

+class SimilarityAction:
     def __str__(self):
-        return
-
+        return "SimilarityAction"

-class FinetuneAction(object):

+class FinetuneAction:
     def __str__(self):
-        return
+        return "FinetuneAction"
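`requests_retry_session` now takes `status_forcelist=None` and fills in `list(range(400, 599))` inside the function body when nothing is passed, the usual None-default idiom for list arguments. A short usage sketch, assuming biolmai 0.1.7 is installed and the helper keeps this name (the URL is a placeholder, not a BioLM endpoint):

    from biolmai.api import requests_retry_session

    sess = requests_retry_session(retries=5, backoff_factor=0.5)
    resp = sess.get("https://example.com/")  # placeholder URL
    print(resp.status_code)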
biolmai/asynch.py
CHANGED
@@ -1,24 +1,15 @@
+import asyncio
+from asyncio import create_task, gather, run
+from itertools import zip_longest
+from typing import Dict, List
+
 import aiohttp.resolver
+from aiohttp import ClientSession

 from biolmai.auth import get_user_auth_header
 from biolmai.const import BASE_API_URL, MULTIPROCESS_THREADS

 aiohttp.resolver.DefaultResolver = aiohttp.resolver.AsyncResolver
-from aiohttp import ClientSession, TCPConnector
-from typing import List
-import json
-import asyncio
-
-from asyncio import create_task, gather, run, sleep
-
-
-
-async def get_one(session: ClientSession, slug: str, action: str,
-                  payload: dict, response_key: str):
-    pass
-
-
-from aiohttp import ClientSession


 async def get_one(session: ClientSession, url: str) -> None:
@@ -31,25 +22,31 @@ async def get_one(session: ClientSession, url: str) -> None:
     return text_resp


-async def get_one_biolm(
-
-
-
-
+async def get_one_biolm(
+    session: ClientSession,
+    url: str,
+    pload: dict,
+    headers: dict,
+    response_key: str = None,
+) -> None:
     print("Requesting", url)
-    pload_batch = pload.pop(
-    pload_batch_size = pload.pop(
+    pload_batch = pload.pop("batch")
+    pload_batch_size = pload.pop("batch_size")
     t = aiohttp.ClientTimeout(
         total=1600,  # 27 mins
-        # total timeout (time consists connection establishment for
+        # total timeout (time consists connection establishment for
+        # a new connection or waiting for a free connection from a
+        # pool if pool connection limits are exceeded) default value
+        # is 5 minutes, set to `None` or `0` for unlimited timeout
         sock_connect=None,
-        # Maximal number of seconds for connecting to a peer for a
+        # Maximal number of seconds for connecting to a peer for a
+        # new connection, not given from a pool. See also connect.
         sock_read=None
         # Maximal number of seconds for reading a portion of data from a peer
     )
     async with session.post(url, headers=headers, json=pload, timeout=t) as resp:
         resp_json = await resp.json()
-        resp_json[
+        resp_json["batch"] = pload_batch
         status_code = resp.status
         expected_root_key = response_key
         to_ret = []
@@ -62,9 +59,7 @@ async def get_one_biolm(session: ClientSession,
         else:
             raise ValueError("Unexpected response in parser")
         for idx, item in enumerate(list_of_individual_seq_results):
-            d = {
-                'batch_id': pload_batch,
-                'batch_item': idx}
+            d = {"status_code": status_code, "batch_id": pload_batch, "batch_item": idx}
             if not status_code or status_code != 200:
                 d.update(item)  # Put all resp keys at root there
             else:
@@ -78,16 +73,15 @@ async def get_one_biolm(session: ClientSession,
         # await sleep(2)  # for demo purposes
         # text_resp = text.strip().split("\n", 1)[0]
         # print("Got response from", url, text_resp)
-        return j


 async def async_range(count):
     for i in range(count):
-        yield(i)
+        yield (i)
         await asyncio.sleep(0.0)


-async def get_all(urls: List[str], num_concurrent: int) ->
+async def get_all(urls: List[str], num_concurrent: int) -> list:
     url_iterator = iter(urls)
     keep_going = True
     results = []
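The aiohttp timeout settings above are unchanged; only the long explanatory comments are re-wrapped onto separate lines. For reference, a self-contained sketch of building the same style of `ClientTimeout` and applying it to a single request (placeholder URL, not the BioLM API):

    import asyncio

    import aiohttp

    async def fetch_status(url: str) -> int:
        # total bounds the whole request; sock_connect/sock_read stay unlimited,
        # mirroring the values carried through this diff
        t = aiohttp.ClientTimeout(total=1600, sock_connect=None, sock_read=None)
        async with aiohttp.ClientSession() as session:
            async with session.get(url, timeout=t) as resp:
                return resp.status

    print(asyncio.run(fetch_status("https://example.com/")))  # placeholder URL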
@@ -107,22 +101,26 @@ async def get_all(urls: List[str], num_concurrent: int) -> List:
     return results


-async def get_all_biolm(
-
-
-
-
+async def get_all_biolm(
+    url: str,
+    ploads: List[Dict],
+    headers: dict,
+    num_concurrent: int,
+    response_key: str = None,
+) -> list:
     ploads_iterator = iter(ploads)
     keep_going = True
     results = []
-    connector = aiohttp.TCPConnector(limit=100,
-                                     limit_per_host=50,
-                                     ttl_dns_cache=60)
+    connector = aiohttp.TCPConnector(limit=100, limit_per_host=50, ttl_dns_cache=60)
     ov_tout = aiohttp.ClientTimeout(
         total=None,
-        # total timeout (time consists connection establishment for
+        # total timeout (time consists connection establishment for
+        # a new connection or waiting for a free connection from a
+        # pool if pool connection limits are exceeded) default value
+        # is 5 minutes, set to `None` or `0` for unlimited timeout
         sock_connect=None,
-        # Maximal number of seconds for connecting to a peer for a
+        # Maximal number of seconds for connecting to a peer for a
+        # new connection, not given from a pool. See also connect.
         sock_read=None
         # Maximal number of seconds for reading a portion of data from a peer
     )
@@ -135,35 +133,31 @@ async def get_all_biolm(url: str,
             except StopIteration:
                 keep_going = False
                 break
-            new_task = create_task(
-
+            new_task = create_task(
+                get_one_biolm(session, url, pload, headers, response_key)
+            )
             tasks.append(new_task)
         res = await gather(*tasks)
         results.extend(res)
     return results


-async def async_main(urls, concurrency) ->
+async def async_main(urls, concurrency) -> list:
     return await get_all(urls, concurrency)


-async def async_api_calls(model_name,
-                          action,
-                          headers,
-                          payloads,
-                          response_key=None):
+async def async_api_calls(model_name, action, headers, payloads, response_key=None):
     """Hit an arbitrary BioLM model inference API."""
     # Normally would POST multiple sequences at once for greater efficiency,
     # but for simplicity sake will do one at at time right now
-    url = f
+    url = f"{BASE_API_URL}/models/{model_name}/{action}/"

     if not isinstance(payloads, (list, dict)):
         err = "API request payload must be a list or dict, got {}"
         raise AssertionError(err.format(type(payloads)))

     concurrency = int(MULTIPROCESS_THREADS)
-    return await get_all_biolm(url, payloads, headers, concurrency,
-                               response_key)
+    return await get_all_biolm(url, payloads, headers, concurrency, response_key)

     # payload = json.dumps(payload)
     # session = requests_retry_session()
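`get_all_biolm` keeps its bounded fan-out after the cleanup: pull up to `num_concurrent` payloads from an iterator, schedule each POST with `create_task`, then `gather` that batch before pulling more. A self-contained sketch of the pattern against a placeholder endpoint (no auth headers or batch bookkeeping, and not the BioLM API):

    import asyncio
    from asyncio import create_task, gather
    from typing import Dict, List

    from aiohttp import ClientSession

    async def post_one(session: ClientSession, url: str, pload: dict) -> dict:
        async with session.post(url, json=pload) as resp:
            return await resp.json()

    async def post_all(url: str, ploads: List[Dict], num_concurrent: int) -> list:
        ploads_iterator = iter(ploads)
        results, keep_going = [], True
        async with ClientSession() as session:
            while keep_going:
                tasks = []
                for _ in range(num_concurrent):
                    try:
                        pload = next(ploads_iterator)
                    except StopIteration:
                        keep_going = False
                        break
                    tasks.append(create_task(post_one(session, url, pload)))
                results.extend(await gather(*tasks))
        return results

    ploads = [{"text": s} for s in ("MDNELE", "MENDEL", "MALWMR")]  # placeholder payloads
    out = asyncio.run(post_all("https://httpbin.org/post", ploads, num_concurrent=2))
    print(len(out))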
@@ -186,33 +180,37 @@ async def async_api_calls(model_name,
     # return response


-def async_api_call_wrapper(grouped_df, slug, action, payload_maker,
-                           response_key):
+def async_api_call_wrapper(grouped_df, slug, action, payload_maker, response_key):
     """Wrap API calls to assist with sequence validation as a pre-cursor to
     each API call.
     """
     model_name = slug
     # payload = payload_maker(grouped_df)
-    init_ploads = grouped_df.groupby(
+    init_ploads = grouped_df.groupby("batch").apply(
+        payload_maker, include_batch_size=True
+    )
     ploads = init_ploads.to_list()
-    init_ploads = init_ploads.to_frame(name=
-    init_ploads[
+    init_ploads = init_ploads.to_frame(name="pload")
+    init_ploads["batch"] = init_ploads.index
     init_ploads = init_ploads.reset_index(drop=True)
     assert len(ploads) == init_ploads.shape[0]
-    for inst, b in
-        inst
+    for inst, b in zip_longest(ploads, init_ploads["batch"].to_list()):
+        if inst is None or b is None:
+            raise ValueError(
+                "ploads and init_ploads['batch'] are not of the same length"
+            )
+        inst["batch"] = b

     headers = get_user_auth_header()  # Need to pull each time
-    urls = [
-
-
-
-    ]
+    # urls = [
+    #     "https://github.com",
+    #     "https://stackoverflow.com",
+    #     "https://python.org",
+    # ]
     # concurrency = 3
-    api_resp = run(async_api_calls(model_name, action, headers,
-                                   ploads, response_key))
+    api_resp = run(async_api_calls(model_name, action, headers, ploads, response_key))
     api_resp = [item for sublist in api_resp for item in sublist]
-    api_resp = sorted(api_resp, key=lambda x: x[
+    api_resp = sorted(api_resp, key=lambda x: x["batch_id"])
     # print(api_resp)
     # api_resp = biolmai.api_call(model_name, action, headers, payload,
     #                            response_key)