biolmai 0.1.4__tar.gz → 0.1.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of biolmai might be problematic.
- {biolmai-0.1.4 → biolmai-0.1.5}/PKG-INFO +1 -1
- biolmai-0.1.5/biolmai/__init__.py +11 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/biolmai/api.py +100 -188
- {biolmai-0.1.4 → biolmai-0.1.5}/biolmai/asynch.py +40 -1
- {biolmai-0.1.4 → biolmai-0.1.5}/biolmai/auth.py +43 -0
- biolmai-0.1.5/biolmai/biolmai.py +7 -0
- biolmai-0.1.5/biolmai/cls.py +100 -0
- biolmai-0.1.5/biolmai/payloads.py +34 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/biolmai.egg-info/PKG-INFO +1 -1
- {biolmai-0.1.4 → biolmai-0.1.5}/biolmai.egg-info/SOURCES.txt +16 -3
- biolmai-0.1.5/docs/_static/api_reference_icon.png +0 -0
- biolmai-0.1.5/docs/_static/chat_agents_icon.png +0 -0
- biolmai-0.1.5/docs/_static/jupyter_notebooks_icon.png +0 -0
- biolmai-0.1.5/docs/_static/model_docs_icon.png +0 -0
- biolmai-0.1.5/docs/_static/python_sdk_icon.png +0 -0
- biolmai-0.1.5/docs/_static/tutorials_icon.png +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/conf.py +1 -0
- biolmai-0.1.5/docs/index.rst +107 -0
- biolmai-0.1.5/docs/model-docs/ESM-InverseFold.rst +294 -0
- biolmai-0.1.5/docs/model-docs/Progen-2-OAS.rst +254 -0
- biolmai-0.1.5/docs/model-docs/Progen-2_BFD90.rst +258 -0
- biolmai-0.1.5/docs/model-docs/Progen-2_Medium.rst +256 -0
- biolmai-0.1.5/docs/model-docs/ProteInfer_EC.rst +256 -0
- biolmai-0.1.5/docs/model-docs/ProteInfer_GO.rst +339 -0
- biolmai-0.1.5/docs/model-docs/esm_1v_masking.rst +372 -0
- biolmai-0.1.5/docs/model-docs/esm_suite/esm2_embeddings.rst +249 -0
- {biolmai-0.1.4/docs/model-docs → biolmai-0.1.5/docs/model-docs/esm_suite}/esm2_fold.rst +50 -42
- biolmai-0.1.5/docs/model-docs/index.rst +24 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/python-client/quickstart.rst +2 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/setup.cfg +1 -1
- {biolmai-0.1.4 → biolmai-0.1.5}/setup.py +1 -1
- {biolmai-0.1.4 → biolmai-0.1.5}/tests/test_biolmai.py +33 -12
- biolmai-0.1.4/biolmai/__init__.py +0 -15
- biolmai-0.1.4/biolmai/biolmai.py +0 -153
- biolmai-0.1.4/biolmai/cls.py +0 -1
- biolmai-0.1.4/biolmai/payloads.py +0 -8
- biolmai-0.1.4/docs/index.rst +0 -74
- biolmai-0.1.4/docs/model-docs/admonitions.rst +0 -39
- biolmai-0.1.4/docs/model-docs/esm2_embeddings.rst +0 -10
- {biolmai-0.1.4 → biolmai-0.1.5}/AUTHORS.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/CONTRIBUTING.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/HISTORY.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/LICENSE +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/MANIFEST.in +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/README.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/biolmai/cli.py +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/biolmai/const.py +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/biolmai/ltc.py +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/biolmai/validate.py +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/biolmai.egg-info/dependency_links.txt +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/biolmai.egg-info/entry_points.txt +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/biolmai.egg-info/not-zip-safe +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/biolmai.egg-info/requires.txt +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/biolmai.egg-info/top_level.txt +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/Makefile +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/_static/biolm_docs_logo_dark.png +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/_static/biolm_docs_logo_light.png +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/biolmai.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/make.bat +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/model-docs/img/book_icon.png +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/model-docs/img/esmfold_perf.png +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/modules.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/python-client/authors.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/python-client/contributing.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/python-client/history.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/python-client/installation.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/python-client/readme.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/python-client/usage.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/tutorials_use_cases/bulk_protein_folding.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/tutorials_use_cases/dna_tutorials.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/docs/tutorials_use_cases/protein_tutorials.rst +0 -0
- {biolmai-0.1.4 → biolmai-0.1.5}/tests/__init__.py +0 -0
biolmai-0.1.5/biolmai/__init__.py
@@ -0,0 +1,11 @@
+"""Top-level package for BioLM AI."""
+__author__ = """Nikhil Haas"""
+__email__ = 'nikhil@biolm.ai'
+__version__ = '0.1.5'
+
+from biolmai.auth import get_api_token
+from biolmai.cls import ESMFoldSingleChain, ESMFoldMultiChain, ESM2Embeddings, ESM1v1, ESM1v2, ESM1v3, ESM1v4, ESM1v5
+
+__all__ = [
+
+]
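The new `__init__.py` re-exports the client surface at the package root. A minimal sketch of what that exposes (imports only, no API call is made; not taken from the package's own docs):

```python
# Exercises only the new top-level exports added in 0.1.5; nothing touches the network.
import biolmai

print(biolmai.__version__)           # '0.1.5'
print(biolmai.get_api_token)         # re-exported from biolmai.auth
print(biolmai.ESMFoldSingleChain)    # re-exported from biolmai.cls
```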
{biolmai-0.1.4 → biolmai-0.1.5}/biolmai/api.py
@@ -1,100 +1,22 @@
 """References to API endpoints."""
-
+import datetime
+import time
+
+import requests
+from requests.adapters import HTTPAdapter
+
+import biolmai.auth
+import biolmai
 import inspect
 import pandas as pd
 import numpy as np
-from 
-from biolmai.asynch import async_main, async_api_calls
+from biolmai.asynch import async_api_call_wrapper
 
-from biolmai.biolmai import 
+from biolmai.biolmai import log
 from biolmai.const import MULTIPROCESS_THREADS
 from functools import lru_cache
 
-from biolmai.payloads import INST_DAT_TXT
-from biolmai.validate import ExtendedAAPlusExtra, SingleOccurrenceOf, \
-    UnambiguousAA, \
-    UnambiguousAAPlusExtra
-
-
-def predict_resp_many_in_one_to_many_singles(resp_json, status_code,
-                                             batch_id, local_err, batch_size):
-    expected_root_key = 'predictions'
-    to_ret = []
-    if not local_err and status_code and status_code == 200:
-        list_of_individual_seq_results = resp_json[expected_root_key]
-    elif local_err:
-        list_of_individual_seq_results = [{'error': resp_json}]
-    elif status_code and status_code != 200 and isinstance(resp_json, dict):
-        list_of_individual_seq_results = [resp_json] * batch_size
-    else:
-        raise ValueError("Unexpected response in parser")
-    for idx, item in enumerate(list_of_individual_seq_results):
-        d = {'status_code': status_code,
-             'batch_id': batch_id,
-             'batch_item': idx}
-        if not status_code or status_code != 200:
-            d.update(item) # Put all resp keys at root there
-        else:
-            # We just append one item, mimicking a single seq in POST req/resp
-            d[expected_root_key] = []
-            d[expected_root_key].append(item)
-        to_ret.append(d)
-    return to_ret
-
-
-def async_api_call_wrapper(grouped_df, slug, action, payload_maker,
-                           response_key):
-    """Wrap API calls to assist with sequence validation as a pre-cursor to
-    each API call.
-    """
-    model_name = slug
-    # payload = payload_maker(grouped_df)
-    init_ploads = grouped_df.groupby('batch').apply(payload_maker, include_batch_size=True)
-    ploads = init_ploads.to_list()
-    init_ploads = init_ploads.to_frame(name='pload')
-    init_ploads['batch'] = init_ploads.index
-    init_ploads = init_ploads.reset_index(drop=True)
-    assert len(ploads) == init_ploads.shape[0]
-    for inst, b in zip(ploads, init_ploads['batch'].to_list()):
-        inst['batch'] = b
-
-    headers = get_user_auth_header() # Need to pull each time
-    urls = [
-        "https://github.com",
-        "https://stackoverflow.com",
-        "https://python.org",
-    ]
-    # concurrency = 3
-    api_resp = run(async_api_calls(model_name, action, headers,
-                                   ploads, response_key))
-    api_resp = [item for sublist in api_resp for item in sublist]
-    api_resp = sorted(api_resp, key=lambda x: x['batch_id'])
-    # print(api_resp)
-    # api_resp = biolmai.api_call(model_name, action, headers, payload,
-    #                             response_key)
-    # resp_json = api_resp.json()
-    # batch_id = int(grouped_df.batch.iloc[0])
-    # batch_size = grouped_df.shape[0]
-    # response = predict_resp_many_in_one_to_many_singles(
-    #     resp_json, api_resp.status_code, batch_id, None, batch_size)
-    return api_resp
-
-
-def api_call_wrapper(df, args):
-    """Wrap API calls to assist with sequence validation as a pre-cursor to
-    each API call.
-    """
-    model_name, action, payload_maker, response_key = args
-    payload = payload_maker(df)
-    headers = get_user_auth_header() # Need to pull each time
-    api_resp = biolmai.api_call(model_name, action, headers, payload,
-                                response_key)
-    resp_json = api_resp.json()
-    batch_id = int(df.batch.iloc[0])
-    batch_size = df.shape[0]
-    response = predict_resp_many_in_one_to_many_singles(
-        resp_json, api_resp.status_code, batch_id, None, batch_size)
-    return response
+from biolmai.payloads import INST_DAT_TXT, predict_resp_many_in_one_to_many_singles
 
 
 @lru_cache(maxsize=64)
@@ -131,7 +53,8 @@ def validate(f):
 
         # Is the function we decorated a class method?
        if is_method:
-            name = '{}.{}.{}'.format(f.__module__,
+            name = '{}.{}.{}'.format(f.__module__,
+                                     class_obj_self.__class__.__name__,
                                      f.__name__)
         else:
             name = '{}.{}'.format(f.__module__, f.__name__)
@@ -211,14 +134,12 @@ class APIEndpoint(object):
         else:
             self.multiprocess_threads = MULTIPROCESS_THREADS # Could be False
         # Get correct auth-like headers
-        self.auth_headers = biolmai.get_user_auth_header()
+        self.auth_headers = biolmai.auth.get_user_auth_header()
         self.action_class_strings = tuple([
             c.__name__.replace('Action', '').lower() for c in self.action_classes
         ])
 
-
-    @validate
-    def predict(self, dat):
+    def post_batches(self, dat, slug, action, payload_maker, resp_key):
         keep_batches = dat.loc[~dat.batch.isnull(), ['text', 'batch']]
         if keep_batches.shape[0] == 0:
             pass # Do nothing - we made nice JSON errors to return in the DF
@@ -227,10 +148,10 @@ class APIEndpoint(object):
         if keep_batches.shape[0] > 0:
             api_resps = async_api_call_wrapper(
                 keep_batches,
-
-
-
-
+                slug,
+                action,
+                payload_maker,
+                resp_key
             )
             if isinstance(api_resps, pd.DataFrame):
                 batch_res = api_resps.explode('api_resp') # Should be lists of results
@@ -253,28 +174,100 @@ class APIEndpoint(object):
             dat = dat.join(keep_batches.reindex(['api_resp'], axis=1))
         else:
             dat['api_resp'] = None
+        return dat
 
+    def unpack_local_validations(self, dat):
         dat.loc[
             dat.api_resp.isnull(), 'api_resp'
         ] = dat.loc[~dat.validation.isnull(), 'validation'].apply(
             predict_resp_many_in_one_to_many_singles,
             args=(None, None, True, None)).explode()
 
+        return dat
+
+    @convert_input
+    @validate
+    def predict(self, dat):
+        dat = self.post_batches(dat, self.slug, 'predict', INST_DAT_TXT, 'predictions')
+        dat = self.unpack_local_validations(dat)
         return dat.api_resp.replace(np.nan, None).tolist()
 
     def infer(self, dat):
         return self.predict(dat)
 
+    @convert_input
     @validate
-    def 
-
-
-
-
-
-
-
-
+    def transform(self, dat):
+        dat = self.post_batches(dat, self.slug, 'transform', INST_DAT_TXT, 'predictions')
+        dat = self.unpack_local_validations(dat)
+        return dat.api_resp.replace(np.nan, None).tolist()
+
+    @convert_input
+    @validate
+    def generate(self, dat):
+        dat = self.post_batches(dat, self.slug, 'generate', INST_DAT_TXT, 'generated')
+        dat = self.unpack_local_validations(dat)
+        return dat.api_resp.replace(np.nan, None).tolist()
+
+
+def retry_minutes(sess, URL, HEADERS, dat, timeout, mins):
+    """Retry for N minutes."""
+    HEADERS.update({'Content-Type': 'application/json'})
+    attempts, max_attempts = 0, 5
+    try:
+        now = datetime.datetime.now()
+        try_until = now + datetime.timedelta(minutes=mins)
+        while datetime.datetime.now() < try_until and attempts < max_attempts:
+            response = None
+            try:
+                log.info('Trying {}'.format(datetime.datetime.now()))
+                response = sess.post(
+                    URL,
+                    headers=HEADERS,
+                    data=dat,
+                    timeout=timeout
+                )
+                if response.status_code not in (400, 404):
+                    response.raise_for_status()
+                if 'error' in response.json():
+                    raise ValueError(response.json().dumps())
+                else:
+                    break
+            except Exception as e:
+                log.warning(e)
+                if response:
+                    log.warning(response.text)
+                time.sleep(5) # Wait 5 seconds between tries
+                attempts += 1
+        if response is None:
+            err = "Got Nonetype response"
+            raise ValueError(err)
+        elif 'Server Error' in response.text:
+            err = "Got Server Error"
+            raise ValueError(err)
+    except Exception as e:
+        return response
+    return response
+
+
+def requests_retry_session(
+    retries=3,
+    backoff_factor=0.3,
+    status_forcelist=list(range(400, 599)),
+    session=None,
+):
+    session = session or requests.Session()
+    retry = Retry(
+        total=retries,
+        read=retries,
+        connect=retries,
+        backoff_factor=backoff_factor,
+        status_forcelist=status_forcelist
+    )
+    adapter = HTTPAdapter(max_retries=retry)
+    session.mount('http://', adapter)
+    session.mount('https://', adapter)
+    return session
 
 
 class PredictAction(object):
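The two new module-level helpers pair a retrying `requests` session with a bounded retry loop. A hedged sketch of how they could be combined; the URL and token below are placeholders (not from this diff), and `requests_retry_session` assumes `Retry` is imported elsewhere in api.py, which these hunks do not show:

```python
# Sketch only: placeholder URL/token; relies on retry_minutes() and
# requests_retry_session() exactly as added to biolmai/api.py above.
import json
from biolmai.api import requests_retry_session, retry_minutes

sess = requests_retry_session(retries=3, backoff_factor=0.3)
headers = {'Authorization': 'Token YOUR_API_TOKEN'}          # placeholder
url = 'https://example.invalid/models/some-slug/predict/'    # placeholder
payload = json.dumps({"instances": [{"data": {"text": "MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQ"}}]})

# POSTs within a 10-minute window, at most 5 attempts, sleeping 5 s after a failure.
resp = retry_minutes(sess, url, headers, payload, timeout=600, mins=10)
if resp is not None:
    print(resp.status_code)
```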
@@ -311,84 +304,3 @@ class FinetuneAction(object):
 
     def __str__(self):
         return 'FinetuneAction'
-
-
-class ESMFoldSingleChain(APIEndpoint):
-    slug = 'esmfold-singlechain'
-    action_classes = (PredictAction, )
-    seq_classes = (UnambiguousAA(), )
-    batch_size = 2
-
-
-class ESMFoldMultiChain(APIEndpoint):
-    slug = 'esmfold-multichain'
-    action_classes = (PredictAction, )
-    seq_classes = (ExtendedAAPlusExtra(extra=[':']), )
-    batch_size = 2
-
-
-class ESM2Embeddings(APIEndpoint):
-    """Example.
-
-    ```python
-    {
-        "instances": [{
-            "data": {"text": "MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQ"}
-        }]
-    }
-    ```
-    """
-    slug = 'esm2_t33_650M_UR50D'
-    action_classes = (TransformAction,)
-    seq_classes = (UnambiguousAA(), )
-    batch_size = 3
-
-
-class ESM1v1(APIEndpoint):
-    """Example.
-
-    ```python
-    {
-        "instances": [{
-            "data": {"text": "QERLEUTGR<mask>SLGYNIVAT"}
-        }]
-    }
-    ```
-    """
-    slug = 'esm1v_t33_650M_UR90S_1'
-    action_classes = (PredictAction, )
-    seq_classes = (SingleOccurrenceOf('<mask>'),
-                   ExtendedAAPlusExtra(extra=['<mask>']))
-    batch_size = 5
-
-
-class ESM1v2(APIEndpoint):
-    slug = 'esm1v_t33_650M_UR90S_2'
-    action_classes = (PredictAction, )
-    seq_classes = (SingleOccurrenceOf('<mask>'),
-                   ExtendedAAPlusExtra(extra=['<mask>']))
-    batch_size = 5
-
-
-class ESM1v3(APIEndpoint):
-    slug = 'esm1v_t33_650M_UR90S_3'
-    action_classes = (PredictAction, )
-    seq_classes = (SingleOccurrenceOf('<mask>'),
-                   ExtendedAAPlusExtra(extra=['<mask>']))
-    batch_size = 5
-
-
-class ESM1v4(APIEndpoint):
-    slug = 'esm1v_t33_650M_UR90S_4'
-    action_classes = (PredictAction, )
-    seq_classes = (SingleOccurrenceOf('<mask>'),
-                   ExtendedAAPlusExtra(extra=['<mask>']))
-    batch_size = 5
-
-
-class ESM1v5(APIEndpoint):
-    slug = 'esm1v_t33_650M_UR90S_5'
-    action_classes = (PredictAction, )
-    seq_classes = (SingleOccurrenceOf('<mask>'),
-                   ExtendedAAPlusExtra(extra=['<mask>']))
-    batch_size = 5
{biolmai-0.1.4 → biolmai-0.1.5}/biolmai/asynch.py
@@ -1,5 +1,6 @@
 import aiohttp.resolver
 
+from biolmai.auth import get_user_auth_header
 from biolmai.const import BASE_API_URL, MULTIPROCESS_THREADS
 
 aiohttp.resolver.DefaultResolver = aiohttp.resolver.AsyncResolver
@@ -39,7 +40,7 @@ async def get_one_biolm(session: ClientSession,
     pload_batch = pload.pop('batch')
     pload_batch_size = pload.pop('batch_size')
     t = aiohttp.ClientTimeout(
-        total=
+        total=1600, # 27 mins
         # total timeout (time consists connection establishment for a new connection or waiting for a free connection from a pool if pool connection limits are exceeded) default value is 5 minutes, set to `None` or `0` for unlimited timeout
         sock_connect=None,
         # Maximal number of seconds for connecting to a peer for a new connection, not given from a pool. See also connect.
@@ -183,3 +184,41 @@ async def async_api_calls(model_name,
     # headers = get_user_auth_header() # Need to re-get these now
     # response = retry_minutes(session, url, headers, payload, tout, mins=10)
     # return response
+
+
+def async_api_call_wrapper(grouped_df, slug, action, payload_maker,
+                           response_key):
+    """Wrap API calls to assist with sequence validation as a pre-cursor to
+    each API call.
+    """
+    model_name = slug
+    # payload = payload_maker(grouped_df)
+    init_ploads = grouped_df.groupby('batch').apply(payload_maker, include_batch_size=True)
+    ploads = init_ploads.to_list()
+    init_ploads = init_ploads.to_frame(name='pload')
+    init_ploads['batch'] = init_ploads.index
+    init_ploads = init_ploads.reset_index(drop=True)
+    assert len(ploads) == init_ploads.shape[0]
+    for inst, b in zip(ploads, init_ploads['batch'].to_list()):
+        inst['batch'] = b
+
+    headers = get_user_auth_header() # Need to pull each time
+    urls = [
+        "https://github.com",
+        "https://stackoverflow.com",
+        "https://python.org",
+    ]
+    # concurrency = 3
+    api_resp = run(async_api_calls(model_name, action, headers,
+                                   ploads, response_key))
+    api_resp = [item for sublist in api_resp for item in sublist]
+    api_resp = sorted(api_resp, key=lambda x: x['batch_id'])
+    # print(api_resp)
+    # api_resp = biolmai.api_call(model_name, action, headers, payload,
+    #                             response_key)
+    # resp_json = api_resp.json()
+    # batch_id = int(grouped_df.batch.iloc[0])
+    # batch_size = grouped_df.shape[0]
+    # response = predict_resp_many_in_one_to_many_singles(
+    #     resp_json, api_resp.status_code, batch_id, None, batch_size)
+    return api_resp
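`async_api_call_wrapper` (moved here from api.py) expects a DataFrame that already carries `text` and `batch` columns, plus a payload builder such as `INST_DAT_TXT`. A rough sketch of a direct call, assuming BioLM credentials are configured and that the slug and toy sequences below are accepted by the API (both are illustrative):

```python
# Illustrative only: requires working BioLM credentials; normally this is invoked
# for you by APIEndpoint.post_batches() rather than called directly.
import pandas as pd
from biolmai.asynch import async_api_call_wrapper
from biolmai.payloads import INST_DAT_TXT

df = pd.DataFrame({
    'text': ['MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQ',
             'MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQ'],
    'batch': [0, 0],  # one batch of two sequences
})

results = async_api_call_wrapper(df, 'esmfold-singlechain', 'predict',
                                 INST_DAT_TXT, 'predictions')
# Expected to come back flattened and sorted by batch_id, one record per sequence.
```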
{biolmai-0.1.4 → biolmai-0.1.5}/biolmai/auth.py
@@ -125,3 +125,46 @@ def save_access_refresh_token(access_refresh_dict):
     access = access_refresh_dict.get('access')
     refresh = access_refresh_dict.get('refresh')
     validate_user_auth(access=access, refresh=refresh)
+
+
+def get_api_token():
+    """Get a BioLM API token to use with future API requests.
+
+    Copied from https://api.biolm.ai/#d7f87dfd-321f-45ae-99b6-eb203519ddeb.
+    """
+    url = "https://biolm.ai/api/auth/token/"
+
+    payload = json.dumps({
+        "username": os.environ.get("BIOLM_USER"),
+        "password": os.environ.get("BIOLM_PASSWORD")
+    })
+    headers = {
+        'Content-Type': 'application/json'
+    }
+
+    response = requests.request("POST", url, headers=headers, data=payload)
+    response_json = response.json()
+
+    return response_json
+
+
+def get_user_auth_header():
+    """Returns a dict with the appropriate Authorization header, either using
+    an API token from BIOLMAI_TOKEN environment variable, or by reading the
+    credentials file at ~/.biolmai/credntials next."""
+    api_token = os.environ.get('BIOLMAI_TOKEN', None)
+    if api_token:
+        headers = {'Authorization': f'Token {api_token}'}
+    elif os.path.exists(ACCESS_TOK_PATH):
+        with open(ACCESS_TOK_PATH, 'r') as f:
+            access_refresh_dict = json.load(f)
+        access = access_refresh_dict.get('access')
+        refresh = access_refresh_dict.get('refresh')
+        headers = {
+            'Cookie': 'access={};refresh={}'.format(access, refresh),
+            'Content-Type': 'application/json'
+        }
+    else:
+        err = "No https://biolm.ai credentials found. Please run `biolmai status` to debug."
+        raise AssertionError(err)
+    return headers
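The two additions cover both authentication paths: exchanging a username/password for tokens, and building per-request headers from `BIOLMAI_TOKEN` or the saved credentials file. A minimal sketch, assuming the relevant environment variables are already exported (the values shown are placeholders):

```python
# Placeholder credentials; both helpers are used exactly as defined in auth.py above.
import os
from biolmai.auth import get_api_token, get_user_auth_header

# Path 1: BIOLM_USER / BIOLM_PASSWORD are exchanged for tokens at biolm.ai.
tokens = get_api_token()            # expected to contain 'access'/'refresh' on success
print(sorted(tokens.keys()))

# Path 2: a pre-issued token in BIOLMAI_TOKEN becomes an Authorization header.
os.environ['BIOLMAI_TOKEN'] = 'example-token'   # placeholder
print(get_user_auth_header())       # {'Authorization': 'Token example-token'}
```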
biolmai-0.1.5/biolmai/cls.py
@@ -0,0 +1,100 @@
+"""API inference classes."""
+from biolmai.api import APIEndpoint, PredictAction, TransformAction, GenerateAction
+from biolmai.validate import UnambiguousAA, ExtendedAAPlusExtra, SingleOccurrenceOf
+
+
+class ESMFoldSingleChain(APIEndpoint):
+    slug = 'esmfold-singlechain'
+    action_classes = (PredictAction, )
+    seq_classes = (UnambiguousAA(), )
+    batch_size = 2
+
+
+class ESMFoldMultiChain(APIEndpoint):
+    slug = 'esmfold-multichain'
+    action_classes = (PredictAction, )
+    seq_classes = (ExtendedAAPlusExtra(extra=[':']), )
+    batch_size = 2
+
+
+class ESM2Embeddings(APIEndpoint):
+    """Example.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        {
+            "instances": [{
+                "data": {"text": "MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQ"}
+            }]
+        }
+    """
+    slug = 'esm2_t33_650M_UR50D'
+    action_classes = (TransformAction,)
+    seq_classes = (UnambiguousAA(), )
+    batch_size = 1
+
+
+class ESM1v1(APIEndpoint):
+    """Example.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        {
+            "instances": [{
+                "data": {"text": "QERLEUTGR<mask>SLGYNIVAT"}
+            }]
+        }
+    """
+    slug = 'esm1v_t33_650M_UR90S_1'
+    action_classes = (PredictAction, )
+    seq_classes = (SingleOccurrenceOf('<mask>'),
+                   ExtendedAAPlusExtra(extra=['<mask>']))
+    batch_size = 5
+
+
+class ESM1v2(APIEndpoint):
+    slug = 'esm1v_t33_650M_UR90S_2'
+    action_classes = (PredictAction, )
+    seq_classes = (SingleOccurrenceOf('<mask>'),
+                   ExtendedAAPlusExtra(extra=['<mask>']))
+    batch_size = 5
+
+
+class ESM1v3(APIEndpoint):
+    slug = 'esm1v_t33_650M_UR90S_3'
+    action_classes = (PredictAction, )
+    seq_classes = (SingleOccurrenceOf('<mask>'),
+                   ExtendedAAPlusExtra(extra=['<mask>']))
+    batch_size = 5
+
+
+class ESM1v4(APIEndpoint):
+    slug = 'esm1v_t33_650M_UR90S_4'
+    action_classes = (PredictAction, )
+    seq_classes = (SingleOccurrenceOf('<mask>'),
+                   ExtendedAAPlusExtra(extra=['<mask>']))
+    batch_size = 5
+
+
+class ESM1v5(APIEndpoint):
+    slug = 'esm1v_t33_650M_UR90S_5'
+    action_classes = (PredictAction, )
+    seq_classes = (SingleOccurrenceOf('<mask>'),
+                   ExtendedAAPlusExtra(extra=['<mask>']))
+    batch_size = 5
+
+
+class ESMIF1(APIEndpoint):
+    slug = 'esmif1'
+    action_classes = (GenerateAction, )
+    seq_classes = tuple([])
+    batch_size = 2
+
+
+class Progen2(APIEndpoint):
+    slug = 'progen2'
+    action_classes = (GenerateAction, )
+    seq_classes = tuple([])
+    batch_size = 1
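Each class here is a thin declaration over `APIEndpoint`: `slug` selects the hosted model, `action_classes` decides whether `predict`, `transform`, or `generate` applies, and `seq_classes`/`batch_size` drive local validation and batching. A hedged usage sketch (credentials assumed configured; the sequences are the ones from the docstrings above):

```python
# Illustrative only; assumes BioLM credentials are set up as described in auth.py.
from biolmai import ESMFoldSingleChain, ESM1v1, ESM2Embeddings

# Structure prediction: plain unambiguous amino-acid sequences, sent in batches of 2.
folds = ESMFoldSingleChain().predict(['MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQ'])

# Masked-residue scoring: exactly one <mask> per sequence is required.
scores = ESM1v1().predict(['QERLEUTGR<mask>SLGYNIVAT'])

# Embeddings: ESM2Embeddings exposes transform() rather than predict().
embeddings = ESM2Embeddings().transform(['MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQ'])
```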
biolmai-0.1.5/biolmai/payloads.py
@@ -0,0 +1,34 @@
+def INST_DAT_TXT(batch, include_batch_size=False):
+    d = {"instances": []}
+    for idx, row in batch.iterrows():
+        inst = {"data": {"text": row.text}}
+        d['instances'].append(inst)
+    if include_batch_size is True:
+        d['batch_size'] = len(d['instances'])
+    return d
+
+
+def predict_resp_many_in_one_to_many_singles(resp_json, status_code,
+                                             batch_id, local_err, batch_size):
+    expected_root_key = 'predictions'
+    to_ret = []
+    if not local_err and status_code and status_code == 200:
+        list_of_individual_seq_results = resp_json[expected_root_key]
+    elif local_err:
+        list_of_individual_seq_results = [{'error': resp_json}]
+    elif status_code and status_code != 200 and isinstance(resp_json, dict):
+        list_of_individual_seq_results = [resp_json] * batch_size
+    else:
+        raise ValueError("Unexpected response in parser")
+    for idx, item in enumerate(list_of_individual_seq_results):
+        d = {'status_code': status_code,
+             'batch_id': batch_id,
+             'batch_item': idx}
+        if not status_code or status_code != 200:
+            d.update(item) # Put all resp keys at root there
+        else:
+            # We just append one item, mimicking a single seq in POST req/resp
+            d[expected_root_key] = []
+            d[expected_root_key].append(item)
+        to_ret.append(d)
+    return to_ret
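`INST_DAT_TXT` turns one batch of the input DataFrame into the JSON body the API expects, and `predict_resp_many_in_one_to_many_singles` fans a batched response back out into per-sequence records. A self-contained sketch with a mocked response (no network call; the 'score' field is invented for illustration):

```python
# Self-contained: the response dict below is mocked, not a real BioLM reply.
import pandas as pd
from biolmai.payloads import INST_DAT_TXT, predict_resp_many_in_one_to_many_singles

batch = pd.DataFrame({'text': ['MSILV', 'QERLE'], 'batch': [0, 0]})
payload = INST_DAT_TXT(batch, include_batch_size=True)
# {'instances': [{'data': {'text': 'MSILV'}}, {'data': {'text': 'QERLE'}}], 'batch_size': 2}

mock_resp = {'predictions': [{'score': 0.9}, {'score': 0.1}]}
rows = predict_resp_many_in_one_to_many_singles(mock_resp, 200, batch_id=0,
                                                local_err=None, batch_size=2)
# Two records, each with status_code/batch_id/batch_item and a one-element
# 'predictions' list, mimicking single-sequence responses.
```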
{biolmai-0.1.4 → biolmai-0.1.5}/biolmai.egg-info/SOURCES.txt
@@ -30,11 +30,24 @@ docs/conf.py
 docs/index.rst
 docs/make.bat
 docs/modules.rst
+docs/_static/api_reference_icon.png
 docs/_static/biolm_docs_logo_dark.png
 docs/_static/biolm_docs_logo_light.png
-docs/
-docs/
-docs/
+docs/_static/chat_agents_icon.png
+docs/_static/jupyter_notebooks_icon.png
+docs/_static/model_docs_icon.png
+docs/_static/python_sdk_icon.png
+docs/_static/tutorials_icon.png
+docs/model-docs/ESM-InverseFold.rst
+docs/model-docs/Progen-2-OAS.rst
+docs/model-docs/Progen-2_BFD90.rst
+docs/model-docs/Progen-2_Medium.rst
+docs/model-docs/ProteInfer_EC.rst
+docs/model-docs/ProteInfer_GO.rst
+docs/model-docs/esm_1v_masking.rst
+docs/model-docs/index.rst
+docs/model-docs/esm_suite/esm2_embeddings.rst
+docs/model-docs/esm_suite/esm2_fold.rst
 docs/model-docs/img/book_icon.png
 docs/model-docs/img/esmfold_perf.png
 docs/python-client/authors.rst

Binary file changes (6 files, no textual diff shown).