biolmai 0.1.3__tar.gz → 0.1.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of biolmai might be problematic. Click here for more details.
- {biolmai-0.1.3 → biolmai-0.1.5}/PKG-INFO +1 -1
- biolmai-0.1.5/biolmai/__init__.py +11 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/biolmai/api.py +122 -107
- biolmai-0.1.5/biolmai/asynch.py +224 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/biolmai/auth.py +43 -0
- biolmai-0.1.5/biolmai/biolmai.py +7 -0
- biolmai-0.1.5/biolmai/cls.py +100 -0
- biolmai-0.1.5/biolmai/const.py +27 -0
- biolmai-0.1.5/biolmai/payloads.py +34 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/biolmai.egg-info/PKG-INFO +1 -1
- {biolmai-0.1.3 → biolmai-0.1.5}/biolmai.egg-info/SOURCES.txt +17 -4
- {biolmai-0.1.3 → biolmai-0.1.5}/biolmai.egg-info/requires.txt +0 -1
- biolmai-0.1.5/docs/_static/api_reference_icon.png +0 -0
- biolmai-0.1.5/docs/_static/chat_agents_icon.png +0 -0
- biolmai-0.1.5/docs/_static/jupyter_notebooks_icon.png +0 -0
- biolmai-0.1.5/docs/_static/model_docs_icon.png +0 -0
- biolmai-0.1.5/docs/_static/python_sdk_icon.png +0 -0
- biolmai-0.1.5/docs/_static/tutorials_icon.png +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/biolmai.rst +3 -3
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/conf.py +1 -0
- biolmai-0.1.5/docs/index.rst +107 -0
- biolmai-0.1.5/docs/model-docs/ESM-InverseFold.rst +294 -0
- biolmai-0.1.5/docs/model-docs/Progen-2-OAS.rst +254 -0
- biolmai-0.1.5/docs/model-docs/Progen-2_BFD90.rst +258 -0
- biolmai-0.1.5/docs/model-docs/Progen-2_Medium.rst +256 -0
- biolmai-0.1.5/docs/model-docs/ProteInfer_EC.rst +256 -0
- biolmai-0.1.5/docs/model-docs/ProteInfer_GO.rst +339 -0
- biolmai-0.1.5/docs/model-docs/esm_1v_masking.rst +372 -0
- biolmai-0.1.5/docs/model-docs/esm_suite/esm2_embeddings.rst +249 -0
- {biolmai-0.1.3/docs/model-docs → biolmai-0.1.5/docs/model-docs/esm_suite}/esm2_fold.rst +50 -42
- biolmai-0.1.5/docs/model-docs/index.rst +24 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/python-client/quickstart.rst +2 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/setup.cfg +1 -1
- {biolmai-0.1.3 → biolmai-0.1.5}/setup.py +2 -3
- {biolmai-0.1.3 → biolmai-0.1.5}/tests/test_biolmai.py +43 -13
- biolmai-0.1.3/biolmai/__init__.py +0 -15
- biolmai-0.1.3/biolmai/async.py +0 -6
- biolmai-0.1.3/biolmai/biolmai.py +0 -153
- biolmai-0.1.3/biolmai/cls.py +0 -1
- biolmai-0.1.3/biolmai/const.py +0 -13
- biolmai-0.1.3/biolmai/payloads.py +0 -6
- biolmai-0.1.3/docs/index.rst +0 -74
- biolmai-0.1.3/docs/model-docs/admonitions.rst +0 -39
- biolmai-0.1.3/docs/model-docs/esm2_embeddings.rst +0 -10
- {biolmai-0.1.3 → biolmai-0.1.5}/AUTHORS.rst +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/CONTRIBUTING.rst +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/HISTORY.rst +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/LICENSE +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/MANIFEST.in +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/README.rst +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/biolmai/cli.py +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/biolmai/ltc.py +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/biolmai/validate.py +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/biolmai.egg-info/dependency_links.txt +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/biolmai.egg-info/entry_points.txt +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/biolmai.egg-info/not-zip-safe +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/biolmai.egg-info/top_level.txt +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/Makefile +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/_static/biolm_docs_logo_dark.png +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/_static/biolm_docs_logo_light.png +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/make.bat +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/model-docs/img/book_icon.png +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/model-docs/img/esmfold_perf.png +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/modules.rst +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/python-client/authors.rst +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/python-client/contributing.rst +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/python-client/history.rst +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/python-client/installation.rst +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/python-client/readme.rst +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/python-client/usage.rst +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/tutorials_use_cases/bulk_protein_folding.rst +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/tutorials_use_cases/dna_tutorials.rst +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/docs/tutorials_use_cases/protein_tutorials.rst +0 -0
- {biolmai-0.1.3 → biolmai-0.1.5}/tests/__init__.py +0 -0
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Top-level package for BioLM AI."""
|
|
2
|
+
__author__ = """Nikhil Haas"""
|
|
3
|
+
__email__ = 'nikhil@biolm.ai'
|
|
4
|
+
__version__ = '0.1.5'
|
|
5
|
+
|
|
6
|
+
from biolmai.auth import get_api_token
|
|
7
|
+
from biolmai.cls import ESMFoldSingleChain, ESMFoldMultiChain, ESM2Embeddings, ESM1v1, ESM1v2, ESM1v3, ESM1v4, ESM1v5
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
|
|
11
|
+
]
|
|
@@ -1,61 +1,22 @@
|
|
|
1
1
|
"""References to API endpoints."""
|
|
2
|
-
|
|
2
|
+
import datetime
|
|
3
|
+
import time
|
|
4
|
+
|
|
5
|
+
import requests
|
|
6
|
+
from requests.adapters import HTTPAdapter
|
|
7
|
+
|
|
8
|
+
import biolmai.auth
|
|
9
|
+
import biolmai
|
|
3
10
|
import inspect
|
|
4
11
|
import pandas as pd
|
|
5
12
|
import numpy as np
|
|
13
|
+
from biolmai.asynch import async_api_call_wrapper
|
|
6
14
|
|
|
7
|
-
from biolmai.biolmai import
|
|
15
|
+
from biolmai.biolmai import log
|
|
8
16
|
from biolmai.const import MULTIPROCESS_THREADS
|
|
9
|
-
if MULTIPROCESS_THREADS:
|
|
10
|
-
from pandarallel import pandarallel
|
|
11
|
-
pandarallel.initialize(progress_bar=False,
|
|
12
|
-
nb_workers=int(MULTIPROCESS_THREADS), verbose=2)
|
|
13
17
|
from functools import lru_cache
|
|
14
18
|
|
|
15
|
-
from biolmai.payloads import INST_DAT_TXT
|
|
16
|
-
from biolmai.validate import UnambiguousAA
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def predict_resp_many_in_one_to_many_singles(resp_json, status_code,
|
|
20
|
-
batch_id, local_err, batch_size):
|
|
21
|
-
expected_root_key = 'predictions'
|
|
22
|
-
to_ret = []
|
|
23
|
-
if not local_err and status_code and status_code == 200:
|
|
24
|
-
list_of_individual_seq_results = resp_json[expected_root_key]
|
|
25
|
-
elif local_err:
|
|
26
|
-
list_of_individual_seq_results = [{'error': resp_json}]
|
|
27
|
-
elif status_code and status_code != 200 and isinstance(resp_json, dict):
|
|
28
|
-
list_of_individual_seq_results = [resp_json] * batch_size
|
|
29
|
-
else:
|
|
30
|
-
raise ValueError("Unexpected response in parser")
|
|
31
|
-
for idx, item in enumerate(list_of_individual_seq_results):
|
|
32
|
-
d = {'status_code': status_code,
|
|
33
|
-
'batch_id': batch_id,
|
|
34
|
-
'batch_item': idx}
|
|
35
|
-
if not status_code or status_code != 200:
|
|
36
|
-
d.update(item) # Put all resp keys at root there
|
|
37
|
-
else:
|
|
38
|
-
# We just append one item, mimicking a single seq in POST req/resp
|
|
39
|
-
d[expected_root_key] = []
|
|
40
|
-
d[expected_root_key].append(item)
|
|
41
|
-
to_ret.append(d)
|
|
42
|
-
return to_ret
|
|
43
|
-
|
|
44
|
-
def api_call_wrapper(df, args):
|
|
45
|
-
"""Wrap API calls to assist with sequence validation as a pre-cursor to
|
|
46
|
-
each API call.
|
|
47
|
-
"""
|
|
48
|
-
model_name, action, payload_maker, response_key = args
|
|
49
|
-
payload = payload_maker(df)
|
|
50
|
-
headers = get_user_auth_header() # Need to pull each time
|
|
51
|
-
api_resp = biolmai.api_call(model_name, action, headers, payload,
|
|
52
|
-
response_key)
|
|
53
|
-
resp_json = api_resp.json()
|
|
54
|
-
batch_id = int(df.batch.iloc[0])
|
|
55
|
-
batch_size = df.shape[0]
|
|
56
|
-
response = predict_resp_many_in_one_to_many_singles(
|
|
57
|
-
resp_json, api_resp.status_code, batch_id, None, batch_size)
|
|
58
|
-
return response
|
|
19
|
+
from biolmai.payloads import INST_DAT_TXT, predict_resp_many_in_one_to_many_singles
|
|
59
20
|
|
|
60
21
|
|
|
61
22
|
@lru_cache(maxsize=64)
|
|
@@ -92,7 +53,8 @@ def validate(f):
|
|
|
92
53
|
|
|
93
54
|
# Is the function we decorated a class method?
|
|
94
55
|
if is_method:
|
|
95
|
-
name = '{}.{}.{}'.format(f.__module__,
|
|
56
|
+
name = '{}.{}.{}'.format(f.__module__,
|
|
57
|
+
class_obj_self.__class__.__name__,
|
|
96
58
|
f.__name__)
|
|
97
59
|
else:
|
|
98
60
|
name = '{}.{}'.format(f.__module__, f.__name__)
|
|
@@ -111,9 +73,9 @@ def validate(f):
|
|
|
111
73
|
for c in class_obj_self.seq_classes:
|
|
112
74
|
# Validate input data against regex
|
|
113
75
|
if class_obj_self.multiprocess_threads:
|
|
114
|
-
validation = input_data.text.
|
|
76
|
+
validation = input_data.text.apply(text_validator, args=(c, ))
|
|
115
77
|
else:
|
|
116
|
-
validation = input_data.text.apply(text_validator, args=(c
|
|
78
|
+
validation = input_data.text.apply(text_validator, args=(c, ))
|
|
117
79
|
if 'validation' not in input_data.columns:
|
|
118
80
|
input_data['validation'] = validation
|
|
119
81
|
else:
|
|
@@ -138,7 +100,7 @@ def validate(f):
|
|
|
138
100
|
|
|
139
101
|
def convert_input(f):
|
|
140
102
|
def wrapper(*args, **kwargs):
|
|
141
|
-
|
|
103
|
+
# Get the user-input data argument to the decorated function
|
|
142
104
|
class_obj_self = args[0]
|
|
143
105
|
input_data = args[1]
|
|
144
106
|
# Make sure we have expected input types
|
|
@@ -172,44 +134,35 @@ class APIEndpoint(object):
|
|
|
172
134
|
else:
|
|
173
135
|
self.multiprocess_threads = MULTIPROCESS_THREADS # Could be False
|
|
174
136
|
# Get correct auth-like headers
|
|
175
|
-
self.auth_headers = biolmai.get_user_auth_header()
|
|
137
|
+
self.auth_headers = biolmai.auth.get_user_auth_header()
|
|
176
138
|
self.action_class_strings = tuple([
|
|
177
139
|
c.__name__.replace('Action', '').lower() for c in self.action_classes
|
|
178
140
|
])
|
|
179
141
|
|
|
180
|
-
|
|
181
|
-
@validate
|
|
182
|
-
def predict(self, dat):
|
|
142
|
+
def post_batches(self, dat, slug, action, payload_maker, resp_key):
|
|
183
143
|
keep_batches = dat.loc[~dat.batch.isnull(), ['text', 'batch']]
|
|
184
144
|
if keep_batches.shape[0] == 0:
|
|
185
|
-
|
|
145
|
+
pass # Do nothing - we made nice JSON errors to return in the DF
|
|
146
|
+
# err = "No inputs found following local validation"
|
|
186
147
|
# raise AssertionError(err)
|
|
187
|
-
elif self.multiprocess_threads:
|
|
188
|
-
api_resps = keep_batches.groupby('batch').parallel_apply(
|
|
189
|
-
api_call_wrapper,
|
|
190
|
-
(
|
|
191
|
-
self.slug,
|
|
192
|
-
'predict',
|
|
193
|
-
INST_DAT_TXT,
|
|
194
|
-
'predictions'
|
|
195
|
-
),
|
|
196
|
-
)
|
|
197
|
-
else:
|
|
198
|
-
api_resps = keep_batches.groupby('batch').apply(
|
|
199
|
-
api_call_wrapper,
|
|
200
|
-
(
|
|
201
|
-
self.slug,
|
|
202
|
-
'predict',
|
|
203
|
-
INST_DAT_TXT,
|
|
204
|
-
'predictions'
|
|
205
|
-
),
|
|
206
|
-
)
|
|
207
148
|
if keep_batches.shape[0] > 0:
|
|
208
|
-
|
|
149
|
+
api_resps = async_api_call_wrapper(
|
|
150
|
+
keep_batches,
|
|
151
|
+
slug,
|
|
152
|
+
action,
|
|
153
|
+
payload_maker,
|
|
154
|
+
resp_key
|
|
155
|
+
)
|
|
156
|
+
if isinstance(api_resps, pd.DataFrame):
|
|
157
|
+
batch_res = api_resps.explode('api_resp') # Should be lists of results
|
|
158
|
+
len_res = batch_res.shape[0]
|
|
159
|
+
else:
|
|
160
|
+
batch_res = pd.DataFrame({'api_resp': api_resps})
|
|
161
|
+
len_res = batch_res.shape[0]
|
|
209
162
|
orig_request_rows = keep_batches.shape[0]
|
|
210
|
-
if
|
|
163
|
+
if len_res != orig_request_rows:
|
|
211
164
|
err = "Response rows ({}) mismatch with input rows ({})"
|
|
212
|
-
err = err.format(
|
|
165
|
+
err = err.format(len_res, orig_request_rows)
|
|
213
166
|
raise AssertionError(err)
|
|
214
167
|
|
|
215
168
|
# Stack the results horizontally w/ original rows of batches
|
|
@@ -221,28 +174,100 @@ class APIEndpoint(object):
|
|
|
221
174
|
dat = dat.join(keep_batches.reindex(['api_resp'], axis=1))
|
|
222
175
|
else:
|
|
223
176
|
dat['api_resp'] = None
|
|
177
|
+
return dat
|
|
224
178
|
|
|
179
|
+
def unpack_local_validations(self, dat):
|
|
225
180
|
dat.loc[
|
|
226
181
|
dat.api_resp.isnull(), 'api_resp'
|
|
227
182
|
] = dat.loc[~dat.validation.isnull(), 'validation'].apply(
|
|
228
183
|
predict_resp_many_in_one_to_many_singles,
|
|
229
184
|
args=(None, None, True, None)).explode()
|
|
230
185
|
|
|
186
|
+
return dat
|
|
187
|
+
|
|
188
|
+
@convert_input
|
|
189
|
+
@validate
|
|
190
|
+
def predict(self, dat):
|
|
191
|
+
dat = self.post_batches(dat, self.slug, 'predict', INST_DAT_TXT, 'predictions')
|
|
192
|
+
dat = self.unpack_local_validations(dat)
|
|
231
193
|
return dat.api_resp.replace(np.nan, None).tolist()
|
|
232
194
|
|
|
233
195
|
def infer(self, dat):
|
|
234
196
|
return self.predict(dat)
|
|
235
197
|
|
|
198
|
+
@convert_input
|
|
236
199
|
@validate
|
|
237
|
-
def
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
200
|
+
def transform(self, dat):
|
|
201
|
+
dat = self.post_batches(dat, self.slug, 'transform', INST_DAT_TXT, 'predictions')
|
|
202
|
+
dat = self.unpack_local_validations(dat)
|
|
203
|
+
return dat.api_resp.replace(np.nan, None).tolist()
|
|
204
|
+
|
|
205
|
+
@convert_input
|
|
206
|
+
@validate
|
|
207
|
+
def generate(self, dat):
|
|
208
|
+
dat = self.post_batches(dat, self.slug, 'generate', INST_DAT_TXT, 'generated')
|
|
209
|
+
dat = self.unpack_local_validations(dat)
|
|
210
|
+
return dat.api_resp.replace(np.nan, None).tolist()
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def retry_minutes(sess, URL, HEADERS, dat, timeout, mins):
|
|
214
|
+
"""Retry for N minutes."""
|
|
215
|
+
HEADERS.update({'Content-Type': 'application/json'})
|
|
216
|
+
attempts, max_attempts = 0, 5
|
|
217
|
+
try:
|
|
218
|
+
now = datetime.datetime.now()
|
|
219
|
+
try_until = now + datetime.timedelta(minutes=mins)
|
|
220
|
+
while datetime.datetime.now() < try_until and attempts < max_attempts:
|
|
221
|
+
response = None
|
|
222
|
+
try:
|
|
223
|
+
log.info('Trying {}'.format(datetime.datetime.now()))
|
|
224
|
+
response = sess.post(
|
|
225
|
+
URL,
|
|
226
|
+
headers=HEADERS,
|
|
227
|
+
data=dat,
|
|
228
|
+
timeout=timeout
|
|
229
|
+
)
|
|
230
|
+
if response.status_code not in (400, 404):
|
|
231
|
+
response.raise_for_status()
|
|
232
|
+
if 'error' in response.json():
|
|
233
|
+
raise ValueError(response.json().dumps())
|
|
234
|
+
else:
|
|
235
|
+
break
|
|
236
|
+
except Exception as e:
|
|
237
|
+
log.warning(e)
|
|
238
|
+
if response:
|
|
239
|
+
log.warning(response.text)
|
|
240
|
+
time.sleep(5) # Wait 5 seconds between tries
|
|
241
|
+
attempts += 1
|
|
242
|
+
if response is None:
|
|
243
|
+
err = "Got Nonetype response"
|
|
244
|
+
raise ValueError(err)
|
|
245
|
+
elif 'Server Error' in response.text:
|
|
246
|
+
err = "Got Server Error"
|
|
247
|
+
raise ValueError(err)
|
|
248
|
+
except Exception as e:
|
|
249
|
+
return response
|
|
250
|
+
return response
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def requests_retry_session(
|
|
254
|
+
retries=3,
|
|
255
|
+
backoff_factor=0.3,
|
|
256
|
+
status_forcelist=list(range(400, 599)),
|
|
257
|
+
session=None,
|
|
258
|
+
):
|
|
259
|
+
session = session or requests.Session()
|
|
260
|
+
retry = Retry(
|
|
261
|
+
total=retries,
|
|
262
|
+
read=retries,
|
|
263
|
+
connect=retries,
|
|
264
|
+
backoff_factor=backoff_factor,
|
|
265
|
+
status_forcelist=status_forcelist
|
|
266
|
+
)
|
|
267
|
+
adapter = HTTPAdapter(max_retries=retry)
|
|
268
|
+
session.mount('http://', adapter)
|
|
269
|
+
session.mount('https://', adapter)
|
|
270
|
+
return session
|
|
246
271
|
|
|
247
272
|
|
|
248
273
|
class PredictAction(object):
|
|
@@ -250,21 +275,25 @@ class PredictAction(object):
|
|
|
250
275
|
def __str__(self):
|
|
251
276
|
return 'PredictAction'
|
|
252
277
|
|
|
278
|
+
|
|
253
279
|
class GenerateAction(object):
|
|
254
280
|
|
|
255
281
|
def __str__(self):
|
|
256
282
|
return 'GenerateAction'
|
|
257
283
|
|
|
258
|
-
|
|
284
|
+
|
|
285
|
+
class TransformAction(object):
|
|
259
286
|
|
|
260
287
|
def __str__(self):
|
|
261
|
-
return '
|
|
288
|
+
return 'TransformAction'
|
|
289
|
+
|
|
262
290
|
|
|
263
291
|
class ExplainAction(object):
|
|
264
292
|
|
|
265
293
|
def __str__(self):
|
|
266
294
|
return 'ExplainAction'
|
|
267
295
|
|
|
296
|
+
|
|
268
297
|
class SimilarityAction(object):
|
|
269
298
|
|
|
270
299
|
def __str__(self):
|
|
@@ -275,17 +304,3 @@ class FinetuneAction(object):
|
|
|
275
304
|
|
|
276
305
|
def __str__(self):
|
|
277
306
|
return 'FinetuneAction'
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
class ESMFoldSingleChain(APIEndpoint):
|
|
281
|
-
slug = 'esmfold-singlechain'
|
|
282
|
-
action_classes = (PredictAction, )
|
|
283
|
-
seq_classes = (UnambiguousAA, )
|
|
284
|
-
batch_size = 2
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
class ESMFoldMultiChain(APIEndpoint):
|
|
288
|
-
slug = 'esmfold-multichain'
|
|
289
|
-
action_classes = (PredictAction, )
|
|
290
|
-
seq_classes = (UnambiguousAA, )
|
|
291
|
-
batch_size = 2
|
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
import aiohttp.resolver
|
|
2
|
+
|
|
3
|
+
from biolmai.auth import get_user_auth_header
|
|
4
|
+
from biolmai.const import BASE_API_URL, MULTIPROCESS_THREADS
|
|
5
|
+
|
|
6
|
+
aiohttp.resolver.DefaultResolver = aiohttp.resolver.AsyncResolver
|
|
7
|
+
from aiohttp import ClientSession, TCPConnector
|
|
8
|
+
from typing import List
|
|
9
|
+
import json
|
|
10
|
+
import asyncio
|
|
11
|
+
|
|
12
|
+
from asyncio import create_task, gather, run, sleep
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
async def get_one(session: ClientSession, slug: str, action: str,
|
|
17
|
+
payload: dict, response_key: str):
|
|
18
|
+
pass
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
from aiohttp import ClientSession
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
async def get_one(session: ClientSession, url: str) -> None:
|
|
25
|
+
print("Requesting", url)
|
|
26
|
+
async with session.get(url) as resp:
|
|
27
|
+
text = await resp.text()
|
|
28
|
+
# await sleep(2) # for demo purposes
|
|
29
|
+
text_resp = text.strip().split("\n", 1)[0]
|
|
30
|
+
print("Got response from", url, text_resp)
|
|
31
|
+
return text_resp
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
async def get_one_biolm(session: ClientSession,
|
|
35
|
+
url: str,
|
|
36
|
+
pload: dict,
|
|
37
|
+
headers: dict,
|
|
38
|
+
response_key: str = None) -> None:
|
|
39
|
+
print("Requesting", url)
|
|
40
|
+
pload_batch = pload.pop('batch')
|
|
41
|
+
pload_batch_size = pload.pop('batch_size')
|
|
42
|
+
t = aiohttp.ClientTimeout(
|
|
43
|
+
total=1600, # 27 mins
|
|
44
|
+
# total timeout (time consists connection establishment for a new connection or waiting for a free connection from a pool if pool connection limits are exceeded) default value is 5 minutes, set to `None` or `0` for unlimited timeout
|
|
45
|
+
sock_connect=None,
|
|
46
|
+
# Maximal number of seconds for connecting to a peer for a new connection, not given from a pool. See also connect.
|
|
47
|
+
sock_read=None
|
|
48
|
+
# Maximal number of seconds for reading a portion of data from a peer
|
|
49
|
+
)
|
|
50
|
+
async with session.post(url, headers=headers, json=pload, timeout=t) as resp:
|
|
51
|
+
resp_json = await resp.json()
|
|
52
|
+
resp_json['batch'] = pload_batch
|
|
53
|
+
status_code = resp.status
|
|
54
|
+
expected_root_key = response_key
|
|
55
|
+
to_ret = []
|
|
56
|
+
if status_code and status_code == 200:
|
|
57
|
+
list_of_individual_seq_results = resp_json[expected_root_key]
|
|
58
|
+
# elif local_err:
|
|
59
|
+
# list_of_individual_seq_results = [{'error': resp_json}]
|
|
60
|
+
elif status_code and status_code != 200 and isinstance(resp_json, dict):
|
|
61
|
+
list_of_individual_seq_results = [resp_json] * pload_batch_size
|
|
62
|
+
else:
|
|
63
|
+
raise ValueError("Unexpected response in parser")
|
|
64
|
+
for idx, item in enumerate(list_of_individual_seq_results):
|
|
65
|
+
d = {'status_code': status_code,
|
|
66
|
+
'batch_id': pload_batch,
|
|
67
|
+
'batch_item': idx}
|
|
68
|
+
if not status_code or status_code != 200:
|
|
69
|
+
d.update(item) # Put all resp keys at root there
|
|
70
|
+
else:
|
|
71
|
+
# We just append one item, mimicking a single seq in POST req/resp
|
|
72
|
+
d[expected_root_key] = []
|
|
73
|
+
d[expected_root_key].append(item)
|
|
74
|
+
to_ret.append(d)
|
|
75
|
+
return to_ret
|
|
76
|
+
|
|
77
|
+
# text = await resp.text()
|
|
78
|
+
# await sleep(2) # for demo purposes
|
|
79
|
+
# text_resp = text.strip().split("\n", 1)[0]
|
|
80
|
+
# print("Got response from", url, text_resp)
|
|
81
|
+
return j
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
async def async_range(count):
|
|
85
|
+
for i in range(count):
|
|
86
|
+
yield(i)
|
|
87
|
+
await asyncio.sleep(0.0)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
async def get_all(urls: List[str], num_concurrent: int) -> List:
|
|
91
|
+
url_iterator = iter(urls)
|
|
92
|
+
keep_going = True
|
|
93
|
+
results = []
|
|
94
|
+
async with ClientSession() as session:
|
|
95
|
+
while keep_going:
|
|
96
|
+
tasks = []
|
|
97
|
+
for _ in range(num_concurrent):
|
|
98
|
+
try:
|
|
99
|
+
url = next(url_iterator)
|
|
100
|
+
except StopIteration:
|
|
101
|
+
keep_going = False
|
|
102
|
+
break
|
|
103
|
+
new_task = create_task(get_one(session, url))
|
|
104
|
+
tasks.append(new_task)
|
|
105
|
+
res = await gather(*tasks)
|
|
106
|
+
results.extend(res)
|
|
107
|
+
return results
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
async def get_all_biolm(url: str,
|
|
111
|
+
ploads: List[dict],
|
|
112
|
+
headers: dict,
|
|
113
|
+
num_concurrent: int,
|
|
114
|
+
response_key: str = None) -> List:
|
|
115
|
+
ploads_iterator = iter(ploads)
|
|
116
|
+
keep_going = True
|
|
117
|
+
results = []
|
|
118
|
+
connector = aiohttp.TCPConnector(limit=100,
|
|
119
|
+
limit_per_host=50,
|
|
120
|
+
ttl_dns_cache=60)
|
|
121
|
+
ov_tout = aiohttp.ClientTimeout(
|
|
122
|
+
total=None,
|
|
123
|
+
# total timeout (time consists connection establishment for a new connection or waiting for a free connection from a pool if pool connection limits are exceeded) default value is 5 minutes, set to `None` or `0` for unlimited timeout
|
|
124
|
+
sock_connect=None,
|
|
125
|
+
# Maximal number of seconds for connecting to a peer for a new connection, not given from a pool. See also connect.
|
|
126
|
+
sock_read=None
|
|
127
|
+
# Maximal number of seconds for reading a portion of data from a peer
|
|
128
|
+
)
|
|
129
|
+
async with ClientSession(connector=connector, timeout=ov_tout) as session:
|
|
130
|
+
while keep_going:
|
|
131
|
+
tasks = []
|
|
132
|
+
for _ in range(num_concurrent):
|
|
133
|
+
try:
|
|
134
|
+
pload = next(ploads_iterator)
|
|
135
|
+
except StopIteration:
|
|
136
|
+
keep_going = False
|
|
137
|
+
break
|
|
138
|
+
new_task = create_task(get_one_biolm(session, url, pload,
|
|
139
|
+
headers, response_key))
|
|
140
|
+
tasks.append(new_task)
|
|
141
|
+
res = await gather(*tasks)
|
|
142
|
+
results.extend(res)
|
|
143
|
+
return results
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
async def async_main(urls, concurrency) -> List:
|
|
147
|
+
return await get_all(urls, concurrency)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
async def async_api_calls(model_name,
|
|
151
|
+
action,
|
|
152
|
+
headers,
|
|
153
|
+
payloads,
|
|
154
|
+
response_key=None):
|
|
155
|
+
"""Hit an arbitrary BioLM model inference API."""
|
|
156
|
+
# Normally would POST multiple sequences at once for greater efficiency,
|
|
157
|
+
# but for simplicity sake will do one at at time right now
|
|
158
|
+
url = f'{BASE_API_URL}/models/{model_name}/{action}/'
|
|
159
|
+
|
|
160
|
+
if not isinstance(payloads, (list, dict)):
|
|
161
|
+
err = "API request payload must be a list or dict, got {}"
|
|
162
|
+
raise AssertionError(err.format(type(payloads)))
|
|
163
|
+
|
|
164
|
+
concurrency = int(MULTIPROCESS_THREADS)
|
|
165
|
+
return await get_all_biolm(url, payloads, headers, concurrency,
|
|
166
|
+
response_key)
|
|
167
|
+
|
|
168
|
+
# payload = json.dumps(payload)
|
|
169
|
+
# session = requests_retry_session()
|
|
170
|
+
# tout = urllib3.util.Timeout(total=180, read=180)
|
|
171
|
+
# response = retry_minutes(session, url, headers, payload, tout, mins=10)
|
|
172
|
+
# # If token expired / invalid, attempt to refresh.
|
|
173
|
+
# if response.status_code == 401 and os.path.exists(ACCESS_TOK_PATH):
|
|
174
|
+
# # Add jitter to slow down in case we're multiprocessing so all threads
|
|
175
|
+
# # don't try to re-authenticate at once
|
|
176
|
+
# time.sleep(random.random() * 4)
|
|
177
|
+
# with open(ACCESS_TOK_PATH, 'r') as f:
|
|
178
|
+
# access_refresh_dict = json.load(f)
|
|
179
|
+
# refresh = access_refresh_dict.get('refresh')
|
|
180
|
+
# if not refresh_access_token(refresh):
|
|
181
|
+
# err = "Unauthenticated! Please run `biolmai status` to debug or " \
|
|
182
|
+
# "`biolmai login`."
|
|
183
|
+
# raise AssertionError(err)
|
|
184
|
+
# headers = get_user_auth_header() # Need to re-get these now
|
|
185
|
+
# response = retry_minutes(session, url, headers, payload, tout, mins=10)
|
|
186
|
+
# return response
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def async_api_call_wrapper(grouped_df, slug, action, payload_maker,
|
|
190
|
+
response_key):
|
|
191
|
+
"""Wrap API calls to assist with sequence validation as a pre-cursor to
|
|
192
|
+
each API call.
|
|
193
|
+
"""
|
|
194
|
+
model_name = slug
|
|
195
|
+
# payload = payload_maker(grouped_df)
|
|
196
|
+
init_ploads = grouped_df.groupby('batch').apply(payload_maker, include_batch_size=True)
|
|
197
|
+
ploads = init_ploads.to_list()
|
|
198
|
+
init_ploads = init_ploads.to_frame(name='pload')
|
|
199
|
+
init_ploads['batch'] = init_ploads.index
|
|
200
|
+
init_ploads = init_ploads.reset_index(drop=True)
|
|
201
|
+
assert len(ploads) == init_ploads.shape[0]
|
|
202
|
+
for inst, b in zip(ploads, init_ploads['batch'].to_list()):
|
|
203
|
+
inst['batch'] = b
|
|
204
|
+
|
|
205
|
+
headers = get_user_auth_header() # Need to pull each time
|
|
206
|
+
urls = [
|
|
207
|
+
"https://github.com",
|
|
208
|
+
"https://stackoverflow.com",
|
|
209
|
+
"https://python.org",
|
|
210
|
+
]
|
|
211
|
+
# concurrency = 3
|
|
212
|
+
api_resp = run(async_api_calls(model_name, action, headers,
|
|
213
|
+
ploads, response_key))
|
|
214
|
+
api_resp = [item for sublist in api_resp for item in sublist]
|
|
215
|
+
api_resp = sorted(api_resp, key=lambda x: x['batch_id'])
|
|
216
|
+
# print(api_resp)
|
|
217
|
+
# api_resp = biolmai.api_call(model_name, action, headers, payload,
|
|
218
|
+
# response_key)
|
|
219
|
+
# resp_json = api_resp.json()
|
|
220
|
+
# batch_id = int(grouped_df.batch.iloc[0])
|
|
221
|
+
# batch_size = grouped_df.shape[0]
|
|
222
|
+
# response = predict_resp_many_in_one_to_many_singles(
|
|
223
|
+
# resp_json, api_resp.status_code, batch_id, None, batch_size)
|
|
224
|
+
return api_resp
|
|
@@ -125,3 +125,46 @@ def save_access_refresh_token(access_refresh_dict):
|
|
|
125
125
|
access = access_refresh_dict.get('access')
|
|
126
126
|
refresh = access_refresh_dict.get('refresh')
|
|
127
127
|
validate_user_auth(access=access, refresh=refresh)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def get_api_token():
|
|
131
|
+
"""Get a BioLM API token to use with future API requests.
|
|
132
|
+
|
|
133
|
+
Copied from https://api.biolm.ai/#d7f87dfd-321f-45ae-99b6-eb203519ddeb.
|
|
134
|
+
"""
|
|
135
|
+
url = "https://biolm.ai/api/auth/token/"
|
|
136
|
+
|
|
137
|
+
payload = json.dumps({
|
|
138
|
+
"username": os.environ.get("BIOLM_USER"),
|
|
139
|
+
"password": os.environ.get("BIOLM_PASSWORD")
|
|
140
|
+
})
|
|
141
|
+
headers = {
|
|
142
|
+
'Content-Type': 'application/json'
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
response = requests.request("POST", url, headers=headers, data=payload)
|
|
146
|
+
response_json = response.json()
|
|
147
|
+
|
|
148
|
+
return response_json
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def get_user_auth_header():
|
|
152
|
+
"""Returns a dict with the appropriate Authorization header, either using
|
|
153
|
+
an API token from BIOLMAI_TOKEN environment variable, or by reading the
|
|
154
|
+
credentials file at ~/.biolmai/credntials next."""
|
|
155
|
+
api_token = os.environ.get('BIOLMAI_TOKEN', None)
|
|
156
|
+
if api_token:
|
|
157
|
+
headers = {'Authorization': f'Token {api_token}'}
|
|
158
|
+
elif os.path.exists(ACCESS_TOK_PATH):
|
|
159
|
+
with open(ACCESS_TOK_PATH, 'r') as f:
|
|
160
|
+
access_refresh_dict = json.load(f)
|
|
161
|
+
access = access_refresh_dict.get('access')
|
|
162
|
+
refresh = access_refresh_dict.get('refresh')
|
|
163
|
+
headers = {
|
|
164
|
+
'Cookie': 'access={};refresh={}'.format(access, refresh),
|
|
165
|
+
'Content-Type': 'application/json'
|
|
166
|
+
}
|
|
167
|
+
else:
|
|
168
|
+
err = "No https://biolm.ai credentials found. Please run `biolmai status` to debug."
|
|
169
|
+
raise AssertionError(err)
|
|
170
|
+
return headers
|