biolmai 0.1.8__tar.gz → 0.1.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of biolmai might be problematic. See the package's registry page for more details.
- {biolmai-0.1.8 → biolmai-0.1.10}/PKG-INFO +2 -1
- {biolmai-0.1.8 → biolmai-0.1.10}/biolmai/__init__.py +2 -1
- {biolmai-0.1.8 → biolmai-0.1.10}/biolmai/api.py +121 -84
- {biolmai-0.1.8 → biolmai-0.1.10}/biolmai/asynch.py +16 -8
- biolmai-0.1.10/biolmai/cls.py +176 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/biolmai/const.py +2 -1
- {biolmai-0.1.8 → biolmai-0.1.10}/biolmai/payloads.py +13 -2
- biolmai-0.1.10/biolmai/validate.py +159 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/biolmai.egg-info/PKG-INFO +2 -1
- biolmai-0.1.10/biolmai.egg-info/SOURCES.txt +83 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/biolmai.egg-info/requires.txt +3 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/conf.py +1 -1
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/index.rst +11 -8
- biolmai-0.1.10/docs/model-docs/ablang/AbLang_API.rst +200 -0
- biolmai-0.1.10/docs/model-docs/ablang/AbLang_Additional.rst +94 -0
- biolmai-0.1.10/docs/model-docs/ablang/index.rst +10 -0
- biolmai-0.1.10/docs/model-docs/biolmtox/BioLMTox_API.rst +293 -0
- biolmai-0.1.10/docs/model-docs/biolmtox/BioLMTox_Additional.rst +62 -0
- biolmai-0.1.10/docs/model-docs/biolmtox/index.rst +10 -0
- biolmai-0.1.10/docs/model-docs/dnabert/DNABERT_Additional.rst +59 -0
- biolmai-0.1.8/docs/model-docs/DNABERT.rst → biolmai-0.1.10/docs/model-docs/dnabert/classifier_ft.rst +34 -67
- biolmai-0.1.10/docs/model-docs/dnabert/index.rst +10 -0
- biolmai-0.1.10/docs/model-docs/esm1v/ESM-1v_API.rst +196 -0
- biolmai-0.1.10/docs/model-docs/esm1v/ESM-1v_Additional.rst +89 -0
- biolmai-0.1.10/docs/model-docs/esm1v/index.rst +10 -0
- biolmai-0.1.10/docs/model-docs/esm2/ESM2_API.rst +450 -0
- biolmai-0.1.10/docs/model-docs/esm2/ESM2_Additional.rst +99 -0
- biolmai-0.1.10/docs/model-docs/esm2/index.rst +10 -0
- biolmai-0.1.10/docs/model-docs/esmfold/ESMFold_API.rst +166 -0
- biolmai-0.1.10/docs/model-docs/esmfold/ESMFold_Additional.rst +108 -0
- biolmai-0.1.10/docs/model-docs/esmfold/index.rst +10 -0
- biolmai-0.1.8/docs/model-docs/ESM_InverseFold.rst → biolmai-0.1.10/docs/model-docs/esmif/ESM_InverseFold_API.rst +97 -163
- biolmai-0.1.10/docs/model-docs/esmif/ESM_InverseFold_Additional.rst +104 -0
- biolmai-0.1.10/docs/model-docs/esmif/index.rst +10 -0
- biolmai-0.1.10/docs/model-docs/finetuning/index.rst +11 -0
- biolmai-0.1.10/docs/model-docs/progen2/ProGen2_API.rst +222 -0
- biolmai-0.1.8/docs/model-docs/progen2/ProGen2_OAS.rst → biolmai-0.1.10/docs/model-docs/progen2/ProGen2_Additional.rst +20 -196
- biolmai-0.1.10/docs/model-docs/proteinfer/ProteInfer_API.rst +181 -0
- biolmai-0.1.10/docs/model-docs/proteinfer/ProteInfer_Additional.rst +92 -0
- biolmai-0.1.10/docs/model-docs/proteinfer/index.rst +10 -0
- biolmai-0.1.10/docs/model-docs/protgpt2/ProtGPT2.rst +75 -0
- biolmai-0.1.8/docs/model-docs/ProtGPT2.rst → biolmai-0.1.10/docs/model-docs/protgpt2/generator_ft.rst +32 -68
- biolmai-0.1.10/docs/model-docs/protgpt2/index.rst +10 -0
- biolmai-0.1.10/docs/python-client/get_started/authentication.rst +82 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/python-client/get_started/installation.rst +2 -2
- biolmai-0.1.10/docs/python-client/get_started/quickstart.rst +25 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/python-client/index.rst +1 -1
- {biolmai-0.1.8 → biolmai-0.1.10}/setup.cfg +1 -1
- {biolmai-0.1.8 → biolmai-0.1.10}/setup.py +4 -1
- biolmai-0.1.10/tests/test_biolmai.py +498 -0
- biolmai-0.1.8/biolmai/cls.py +0 -97
- biolmai-0.1.8/biolmai/validate.py +0 -134
- biolmai-0.1.8/biolmai.egg-info/SOURCES.txt +0 -64
- biolmai-0.1.8/docs/model-docs/ESM-1v.rst +0 -362
- biolmai-0.1.8/docs/model-docs/ESM2_Embeddings.rst +0 -242
- biolmai-0.1.8/docs/model-docs/ESMFold.rst +0 -252
- biolmai-0.1.8/docs/model-docs/ProteInfer_EC.rst +0 -249
- biolmai-0.1.8/docs/model-docs/ProteInfer_GO.rst +0 -329
- biolmai-0.1.8/docs/model-docs/progen2/ProGen2_BFD90.rst +0 -251
- biolmai-0.1.8/docs/model-docs/progen2/ProGen2_Medium.rst +0 -248
- biolmai-0.1.8/docs/python-client/get_started/authorization.rst +0 -9
- biolmai-0.1.8/docs/python-client/get_started/quickstart.rst +0 -15
- biolmai-0.1.8/tests/test_biolmai.py +0 -263
- {biolmai-0.1.8 → biolmai-0.1.10}/AUTHORS.rst +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/CONTRIBUTING.rst +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/HISTORY.rst +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/LICENSE +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/MANIFEST.in +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/README.rst +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/biolmai/auth.py +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/biolmai/biolmai.py +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/biolmai/cli.py +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/biolmai/ltc.py +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/biolmai.egg-info/dependency_links.txt +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/biolmai.egg-info/entry_points.txt +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/biolmai.egg-info/not-zip-safe +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/biolmai.egg-info/top_level.txt +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/Makefile +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/_static/api_reference_icon.png +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/_static/biolm_docs_logo_dark.png +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/_static/biolm_docs_logo_light.png +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/_static/chat_agents_icon.png +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/_static/jupyter_notebooks_icon.png +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/_static/model_docs_icon.png +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/_static/python_sdk_icon.png +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/_static/tutorials_icon.png +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/biolmai.rst +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/make.bat +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/model-docs/img/book_icon.png +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/model-docs/img/esmfold_perf.png +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/model-docs/index.rst +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/model-docs/progen2/index.rst +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/modules.rst +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/python-client/usage.rst +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/docs/tutorials_use_cases/notebooks.rst +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/pyproject.toml +0 -0
- {biolmai-0.1.8 → biolmai-0.1.10}/tests/__init__.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: biolmai
|
|
3
|
-
Version: 0.1.8
|
|
3
|
+
Version: 0.1.10
|
|
4
4
|
Summary: Python client and SDK for https://biolm.ai
|
|
5
5
|
Home-page: https://github.com/BioLM/py-biolm
|
|
6
6
|
Author: Nikhil Haas
|
|
@@ -18,6 +18,7 @@ Classifier: Programming Language :: Python :: 3.9
|
|
|
18
18
|
Classifier: Programming Language :: Python :: 3.10
|
|
19
19
|
Classifier: Programming Language :: Python :: 3.11
|
|
20
20
|
Requires-Python: >=3.6
|
|
21
|
+
Provides-Extra: aiodns
|
|
21
22
|
License-File: LICENSE
|
|
22
23
|
License-File: AUTHORS.rst
|
|
23
24
|
|
|
@@ -15,7 +15,7 @@ import biolmai.auth
|
|
|
15
15
|
from biolmai.asynch import async_api_call_wrapper
|
|
16
16
|
from biolmai.biolmai import log
|
|
17
17
|
from biolmai.const import MULTIPROCESS_THREADS
|
|
18
|
-
from biolmai.payloads import INST_DAT_TXT, predict_resp_many_in_one_to_many_singles
|
|
18
|
+
from biolmai.payloads import INST_DAT_TXT, PARAMS_ITEMS, predict_resp_many_in_one_to_many_singles
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
@lru_cache(maxsize=64)
|
|
@@ -35,65 +35,82 @@ def text_validator(text, c):
|
|
|
35
35
|
except Exception as e:
|
|
36
36
|
return str(e)
|
|
37
37
|
|
|
38
|
+
def combine_validation(x, y):
|
|
39
|
+
if x is None and y is None:
|
|
40
|
+
return None
|
|
41
|
+
elif isinstance(x, str) and y is None:
|
|
42
|
+
return x
|
|
43
|
+
elif x is None and isinstance(y, str):
|
|
44
|
+
return y
|
|
45
|
+
elif isinstance(x, str) and isinstance(y, str):
|
|
46
|
+
return f"{x}\n{y}"
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def validate_action(action):
|
|
50
|
+
def validate(f):
|
|
51
|
+
def wrapper(*args, **kwargs):
|
|
52
|
+
# Get class instance at runtime, so you can access not just
|
|
53
|
+
# APIEndpoints, but any *parent* classes of that,
|
|
54
|
+
# like ESMFoldSinglechain.
|
|
55
|
+
class_obj_self = args[0]
|
|
56
|
+
try:
|
|
57
|
+
is_method = inspect.getfullargspec(f)[0][0] == "self"
|
|
58
|
+
except Exception:
|
|
59
|
+
is_method = False
|
|
38
60
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
# APIEndpoints, but any *parent* classes of that,
|
|
43
|
-
# like ESMFoldSinglechain.
|
|
44
|
-
class_obj_self = args[0]
|
|
45
|
-
try:
|
|
46
|
-
is_method = inspect.getfullargspec(f)[0][0] == "self"
|
|
47
|
-
except Exception:
|
|
48
|
-
is_method = False
|
|
49
|
-
|
|
50
|
-
# Is the function we decorated a class method?
|
|
51
|
-
if is_method:
|
|
52
|
-
name = f"{f.__module__}.{class_obj_self.__class__.__name__}.{f.__name__}"
|
|
53
|
-
else:
|
|
54
|
-
name = f"{f.__module__}.{f.__name__}"
|
|
55
|
-
|
|
56
|
-
if is_method:
|
|
57
|
-
# Splits name, e.g. 'biolmai.api.ESMFoldSingleChain.predict'
|
|
58
|
-
action_method_name = name.split(".")[-1]
|
|
59
|
-
validate_endpoint_action(
|
|
60
|
-
class_obj_self.action_class_strings,
|
|
61
|
-
action_method_name,
|
|
62
|
-
class_obj_self.__class__.__name__,
|
|
63
|
-
)
|
|
64
|
-
|
|
65
|
-
input_data = args[1]
|
|
66
|
-
# Validate each row's text/input based on class attribute `seq_classes`
|
|
67
|
-
for c in class_obj_self.seq_classes:
|
|
68
|
-
# Validate input data against regex
|
|
69
|
-
if class_obj_self.multiprocess_threads:
|
|
70
|
-
validation = input_data.text.apply(text_validator, args=(c,))
|
|
61
|
+
# Is the function we decorated a class method?
|
|
62
|
+
if is_method:
|
|
63
|
+
name = f"{f.__module__}.{class_obj_self.__class__.__name__}.{f.__name__}"
|
|
71
64
|
else:
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
65
|
+
name = f"{f.__module__}.{f.__name__}"
|
|
66
|
+
|
|
67
|
+
if is_method:
|
|
68
|
+
# Splits name, e.g. 'biolmai.api.ESMFoldSingleChain.predict'
|
|
69
|
+
action_method_name = name.split(".")[-1]
|
|
70
|
+
validate_endpoint_action(
|
|
71
|
+
class_obj_self.action_class_strings,
|
|
72
|
+
action_method_name,
|
|
73
|
+
class_obj_self.__class__.__name__,
|
|
78
74
|
)
|
|
79
75
|
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
76
|
+
input_data = args[1]
|
|
77
|
+
# Validate each row's text/input based on class attribute `seq_classes`
|
|
78
|
+
if action == "predict":
|
|
79
|
+
input_classes = class_obj_self.predict_input_classes
|
|
80
|
+
elif action == "encode":
|
|
81
|
+
input_classes = class_obj_self.encode_input_classes
|
|
82
|
+
elif action == "generate":
|
|
83
|
+
input_classes = class_obj_self.generate_input_classes
|
|
84
|
+
elif action == "transform":
|
|
85
|
+
input_classes = class_obj_self.transform_input_classes
|
|
86
|
+
for c in input_classes:
|
|
87
|
+
# Validate input data against regex
|
|
88
|
+
if class_obj_self.multiprocess_threads:
|
|
89
|
+
validation = input_data.text.apply(text_validator, args=(c,))
|
|
90
|
+
else:
|
|
91
|
+
validation = input_data.text.apply(text_validator, args=(c,))
|
|
92
|
+
if "validation" not in input_data.columns:
|
|
93
|
+
input_data["validation"] = validation
|
|
94
|
+
else:
|
|
95
|
+
# masking and loc may be more performant option
|
|
96
|
+
input_data["validation"] = input_data["validation"].combine(validation, combine_validation)
|
|
97
|
+
|
|
98
|
+
# Mark your batches, excluding invalid rows
|
|
99
|
+
valid_dat = input_data.loc[input_data.validation.isnull(), :].copy()
|
|
100
|
+
N = class_obj_self.batch_size # N rows will go per API request
|
|
101
|
+
# JOIN back, which is by index
|
|
102
|
+
if valid_dat.shape[0] != input_data.shape[0]:
|
|
103
|
+
valid_dat["batch"] = np.arange(valid_dat.shape[0]) // N
|
|
104
|
+
input_data = input_data.merge(
|
|
105
|
+
valid_dat.batch, left_index=True, right_index=True, how="left"
|
|
106
|
+
)
|
|
107
|
+
else:
|
|
108
|
+
input_data["batch"] = np.arange(input_data.shape[0]) // N
|
|
109
|
+
res = f(class_obj_self, input_data, **kwargs)
|
|
110
|
+
return res
|
|
96
111
|
|
|
112
|
+
return wrapper
|
|
113
|
+
return validate
|
|
97
114
|
|
|
98
115
|
def convert_input(f):
|
|
99
116
|
def wrapper(*args, **kwargs):
|
|
@@ -123,7 +140,20 @@ def convert_input(f):
|
|
|
123
140
|
|
|
124
141
|
|
|
125
142
|
class APIEndpoint:
|
|
126
|
-
|
|
143
|
+
# Overwrite in parent classes as needed
|
|
144
|
+
batch_size = 3
|
|
145
|
+
params = None
|
|
146
|
+
action_classes = ()
|
|
147
|
+
api_version = 2
|
|
148
|
+
|
|
149
|
+
predict_input_key = "sequence"
|
|
150
|
+
encode_input_key = "sequence"
|
|
151
|
+
generate_input_key = "context"
|
|
152
|
+
|
|
153
|
+
predict_input_classes = ()
|
|
154
|
+
encode_input_classes = ()
|
|
155
|
+
generate_input_classes = ()
|
|
156
|
+
transform_input_classes = ()
|
|
127
157
|
|
|
128
158
|
def __init__(self, multiprocess_threads=None):
|
|
129
159
|
# Check for instance-specific threads, otherwise read from env var
|
|
@@ -137,7 +167,7 @@ class APIEndpoint:
|
|
|
137
167
|
[c.__name__.replace("Action", "").lower() for c in self.action_classes]
|
|
138
168
|
)
|
|
139
169
|
|
|
140
|
-
def post_batches(self, dat, slug, action, payload_maker, resp_key):
|
|
170
|
+
def post_batches(self, dat, slug, action, payload_maker, resp_key, key="sequence", params=None):
|
|
141
171
|
keep_batches = dat.loc[~dat.batch.isnull(), ["text", "batch"]]
|
|
142
172
|
if keep_batches.shape[0] == 0:
|
|
143
173
|
pass # Do nothing - we made nice JSON errors to return in the DF
|
|
@@ -145,7 +175,7 @@ class APIEndpoint:
|
|
|
145
175
|
# raise AssertionError(err)
|
|
146
176
|
if keep_batches.shape[0] > 0:
|
|
147
177
|
api_resps = async_api_call_wrapper(
|
|
148
|
-
keep_batches, slug, action, payload_maker, resp_key
|
|
178
|
+
keep_batches, slug, action, payload_maker, resp_key, api_version=self.api_version, key=key, params=params,
|
|
149
179
|
)
|
|
150
180
|
if isinstance(api_resps, pd.DataFrame):
|
|
151
181
|
batch_res = api_resps.explode("api_resp") # Should be lists of results
|
|
@@ -170,11 +200,11 @@ class APIEndpoint:
|
|
|
170
200
|
dat["api_resp"] = None
|
|
171
201
|
return dat
|
|
172
202
|
|
|
173
|
-
def unpack_local_validations(self, dat):
|
|
203
|
+
def unpack_local_validations(self, dat, response_key):
|
|
174
204
|
dat.loc[dat.api_resp.isnull(), "api_resp"] = (
|
|
175
205
|
dat.loc[~dat.validation.isnull(), "validation"]
|
|
176
206
|
.apply(
|
|
177
|
-
predict_resp_many_in_one_to_many_singles, args=(None, None, True, None)
|
|
207
|
+
predict_resp_many_in_one_to_many_singles, args=(None, None, True, None), response_key=response_key
|
|
178
208
|
)
|
|
179
209
|
.explode()
|
|
180
210
|
)
|
|
@@ -182,39 +212,46 @@ class APIEndpoint:
|
|
|
182
212
|
return dat
|
|
183
213
|
|
|
184
214
|
@convert_input
|
|
185
|
-
@
|
|
186
|
-
def predict(self, dat):
|
|
187
|
-
|
|
188
|
-
|
|
215
|
+
@validate_action("predict")
|
|
216
|
+
def predict(self, dat, params=None):
|
|
217
|
+
if self.api_version == 1:
|
|
218
|
+
dat = self.post_batches(dat, self.slug, "predict", INST_DAT_TXT, "predictions")
|
|
219
|
+
dat = self.unpack_local_validations(dat, "predictions")
|
|
220
|
+
else:
|
|
221
|
+
dat = self.post_batches(dat, self.slug, "predict", PARAMS_ITEMS, "results", key=self.predict_input_key, params=params)
|
|
222
|
+
dat = self.unpack_local_validations(dat,"results")
|
|
189
223
|
return dat.api_resp.replace(np.nan, None).tolist()
|
|
190
224
|
|
|
191
|
-
def infer(self, dat):
|
|
192
|
-
return self.predict(dat)
|
|
225
|
+
def infer(self, dat, params=None):
|
|
226
|
+
return self.predict(dat, params)
|
|
193
227
|
|
|
194
228
|
@convert_input
|
|
195
|
-
@
|
|
229
|
+
@validate_action("transform") # api v1 legacy action
|
|
196
230
|
def transform(self, dat):
|
|
197
231
|
dat = self.post_batches(
|
|
198
232
|
dat, self.slug, "transform", INST_DAT_TXT, "predictions"
|
|
199
233
|
)
|
|
200
|
-
dat = self.unpack_local_validations(dat)
|
|
234
|
+
dat = self.unpack_local_validations(dat,"predictions")
|
|
201
235
|
return dat.api_resp.replace(np.nan, None).tolist()
|
|
202
236
|
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
# dat = self.unpack_local_validations(dat)
|
|
211
|
-
# return dat.api_resp.replace(np.nan, None).tolist()
|
|
237
|
+
@convert_input
|
|
238
|
+
@validate_action("encode")
|
|
239
|
+
def encode(self, dat, params=None):
|
|
240
|
+
|
|
241
|
+
dat = self.post_batches(dat, self.slug, "encode", PARAMS_ITEMS, "results", key=self.encode_input_key, params=params)
|
|
242
|
+
dat = self.unpack_local_validations(dat, "results")
|
|
243
|
+
return dat.api_resp.replace(np.nan, None).tolist()
|
|
212
244
|
|
|
213
245
|
@convert_input
|
|
214
|
-
@
|
|
215
|
-
def generate(self, dat):
|
|
216
|
-
|
|
217
|
-
|
|
246
|
+
@validate_action("generate")
|
|
247
|
+
def generate(self, dat, params=None):
|
|
248
|
+
if self.api_version == 1:
|
|
249
|
+
dat = self.post_batches(dat, self.slug, "generate", INST_DAT_TXT, "generated")
|
|
250
|
+
dat = self.unpack_local_validations(dat, "predictions")
|
|
251
|
+
else:
|
|
252
|
+
dat = self.post_batches(dat, self.slug, "generate", PARAMS_ITEMS, "results", key=self.generate_input_key, params=params)
|
|
253
|
+
dat = self.unpack_local_validations(dat, "results")
|
|
254
|
+
|
|
218
255
|
return dat.api_resp.replace(np.nan, None).tolist()
|
|
219
256
|
|
|
220
257
|
|
|
@@ -290,9 +327,9 @@ class TransformAction:
|
|
|
290
327
|
return "TransformAction"
|
|
291
328
|
|
|
292
329
|
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
330
|
+
class EncodeAction:
|
|
331
|
+
def __str__(self):
|
|
332
|
+
return "EncodeAction"
|
|
296
333
|
|
|
297
334
|
|
|
298
335
|
class ExplainAction:
|
|
@@ -7,7 +7,7 @@ import aiohttp.resolver
|
|
|
7
7
|
from aiohttp import ClientSession
|
|
8
8
|
|
|
9
9
|
from biolmai.auth import get_user_auth_header
|
|
10
|
-
from biolmai.const import BASE_API_URL, MULTIPROCESS_THREADS
|
|
10
|
+
from biolmai.const import BASE_API_URL, BASE_API_URL_V1, MULTIPROCESS_THREADS
|
|
11
11
|
|
|
12
12
|
aiohttp.resolver.DefaultResolver = aiohttp.resolver.AsyncResolver
|
|
13
13
|
|
|
@@ -146,11 +146,14 @@ async def async_main(urls, concurrency) -> list:
|
|
|
146
146
|
return await get_all(urls, concurrency)
|
|
147
147
|
|
|
148
148
|
|
|
149
|
-
async def async_api_calls(model_name, action, headers, payloads, response_key=None):
|
|
149
|
+
async def async_api_calls(model_name, action, headers, payloads, response_key=None, api_version=2):
|
|
150
150
|
"""Hit an arbitrary BioLM model inference API."""
|
|
151
151
|
# Normally would POST multiple sequences at once for greater efficiency,
|
|
152
152
|
# but for simplicity sake will do one at at time right now
|
|
153
|
-
|
|
153
|
+
if api_version == 1:
|
|
154
|
+
url = f"{BASE_API_URL_V1}/models/{model_name}/{action}/"
|
|
155
|
+
else:
|
|
156
|
+
url = f"{BASE_API_URL}/{model_name}/{action}/"
|
|
154
157
|
|
|
155
158
|
if not isinstance(payloads, (list, dict)):
|
|
156
159
|
err = "API request payload must be a list or dict, got {}"
|
|
@@ -180,15 +183,20 @@ async def async_api_calls(model_name, action, headers, payloads, response_key=No
|
|
|
180
183
|
# return response
|
|
181
184
|
|
|
182
185
|
|
|
183
|
-
def async_api_call_wrapper(grouped_df, slug, action, payload_maker, response_key):
|
|
186
|
+
def async_api_call_wrapper(grouped_df, slug, action, payload_maker, response_key, api_version=2, key="sequence", params=None):
|
|
184
187
|
"""Wrap API calls to assist with sequence validation as a pre-cursor to
|
|
185
188
|
each API call.
|
|
186
189
|
"""
|
|
187
190
|
model_name = slug
|
|
188
191
|
# payload = payload_maker(grouped_df)
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
+
if api_version == 1:
|
|
193
|
+
init_ploads = grouped_df.groupby("batch").apply(
|
|
194
|
+
payload_maker, include_batch_size=True
|
|
195
|
+
)
|
|
196
|
+
else:
|
|
197
|
+
init_ploads = grouped_df.groupby("batch").apply(
|
|
198
|
+
payload_maker, key=key, params=params, include_batch_size=True
|
|
199
|
+
)
|
|
192
200
|
ploads = init_ploads.to_list()
|
|
193
201
|
init_ploads = init_ploads.to_frame(name="pload")
|
|
194
202
|
init_ploads["batch"] = init_ploads.index
|
|
@@ -208,7 +216,7 @@ def async_api_call_wrapper(grouped_df, slug, action, payload_maker, response_key
|
|
|
208
216
|
# "https://python.org",
|
|
209
217
|
# ]
|
|
210
218
|
# concurrency = 3
|
|
211
|
-
api_resp = run(async_api_calls(model_name, action, headers, ploads, response_key))
|
|
219
|
+
api_resp = run(async_api_calls(model_name, action, headers, ploads, response_key, api_version))
|
|
212
220
|
api_resp = [item for sublist in api_resp for item in sublist]
|
|
213
221
|
api_resp = sorted(api_resp, key=lambda x: x["batch_id"])
|
|
214
222
|
# print(api_resp)
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
"""API inference classes."""
|
|
2
|
+
from biolmai.api import APIEndpoint, GenerateAction, PredictAction, TransformAction, EncodeAction
|
|
3
|
+
from biolmai.validate import (AAExtended,
|
|
4
|
+
AAExtendedPlusExtra,
|
|
5
|
+
AAUnambiguous,
|
|
6
|
+
AAUnambiguousPlusExtra,
|
|
7
|
+
DNAUnambiguous,
|
|
8
|
+
SingleOrMoreOccurrencesOf,
|
|
9
|
+
SingleOccurrenceOf,
|
|
10
|
+
PDB,
|
|
11
|
+
AAUnambiguousEmpty
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ESMFoldSingleChain(APIEndpoint):
|
|
16
|
+
slug = "esmfold-singlechain"
|
|
17
|
+
action_classes = (PredictAction,)
|
|
18
|
+
predict_input_classes = (AAUnambiguous(),)
|
|
19
|
+
batch_size = 2
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ESMFoldMultiChain(APIEndpoint):
|
|
23
|
+
slug = "esmfold-multichain"
|
|
24
|
+
action_classes = (PredictAction,)
|
|
25
|
+
predict_input_classes = (AAExtendedPlusExtra(extra=[":"]),)
|
|
26
|
+
batch_size = 2
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ESM2(APIEndpoint):
|
|
30
|
+
"""Example.
|
|
31
|
+
|
|
32
|
+
.. highlight:: python
|
|
33
|
+
.. code-block:: python
|
|
34
|
+
|
|
35
|
+
{
|
|
36
|
+
"items": [{
|
|
37
|
+
"sequence": "MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQ"
|
|
38
|
+
}]
|
|
39
|
+
}
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
action_classes = (EncodeAction, PredictAction, )
|
|
43
|
+
encode_input_classes = (AAUnambiguous(),)
|
|
44
|
+
predict_input_classes = (SingleOrMoreOccurrencesOf(token="<mask>"), AAExtendedPlusExtra(extra=["<mask>"]))
|
|
45
|
+
batch_size = 1
|
|
46
|
+
|
|
47
|
+
class ESM2_8M(ESM2):
|
|
48
|
+
slug = "esm2-8m"
|
|
49
|
+
|
|
50
|
+
class ESM2_35M(ESM2):
|
|
51
|
+
slug = "esm2-35m"
|
|
52
|
+
|
|
53
|
+
class ESM2_150M(ESM2):
|
|
54
|
+
slug = "esm2-150m"
|
|
55
|
+
|
|
56
|
+
class ESM2_650M(ESM2):
|
|
57
|
+
slug = "esm2-650m"
|
|
58
|
+
|
|
59
|
+
class ESM2_3B(ESM2):
|
|
60
|
+
slug = "esm2-3b"
|
|
61
|
+
|
|
62
|
+
class ESM1v(APIEndpoint):
|
|
63
|
+
"""Example.
|
|
64
|
+
|
|
65
|
+
.. highlight:: python
|
|
66
|
+
.. code-block:: python
|
|
67
|
+
|
|
68
|
+
{
|
|
69
|
+
"items": [{
|
|
70
|
+
"sequence": "QERLEUTGR<mask>SLGYNIVAT"
|
|
71
|
+
}]
|
|
72
|
+
}
|
|
73
|
+
"""
|
|
74
|
+
action_classes = (PredictAction,)
|
|
75
|
+
predict_input_classes = (SingleOccurrenceOf("<mask>"), AAExtendedPlusExtra(extra=["<mask>"]))
|
|
76
|
+
batch_size = 5
|
|
77
|
+
|
|
78
|
+
class ESM1v1(ESM1v):
|
|
79
|
+
slug = "esm1v-n1"
|
|
80
|
+
|
|
81
|
+
class ESM1v2(ESM1v):
|
|
82
|
+
slug = "esm1v-n2"
|
|
83
|
+
|
|
84
|
+
class ESM1v3(ESM1v):
|
|
85
|
+
slug = "esm1v-n3"
|
|
86
|
+
|
|
87
|
+
class ESM1v4(ESM1v):
|
|
88
|
+
slug = "esm1v-n4"
|
|
89
|
+
|
|
90
|
+
class ESM1v5(ESM1v):
|
|
91
|
+
slug = "esm1v-n5"
|
|
92
|
+
|
|
93
|
+
class ESM1vAll(ESM1v):
|
|
94
|
+
slug = "esm1v-all"
|
|
95
|
+
|
|
96
|
+
class ESMIF1(APIEndpoint):
|
|
97
|
+
slug = "esm-if1"
|
|
98
|
+
action_classes = (GenerateAction,)
|
|
99
|
+
generate_input_classes = PDB
|
|
100
|
+
batch_size = 2
|
|
101
|
+
generate_input_key = "pdb"
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class ProGen2(APIEndpoint):
|
|
105
|
+
action_classes = (GenerateAction,)
|
|
106
|
+
generate_input_classes = (AAUnambiguousEmpty(),)
|
|
107
|
+
batch_size = 1
|
|
108
|
+
|
|
109
|
+
class ProGen2Oas(ProGen2):
|
|
110
|
+
slug = "progen2-oas"
|
|
111
|
+
|
|
112
|
+
class ProGen2Medium(ProGen2):
|
|
113
|
+
slug = "progen2-medium"
|
|
114
|
+
|
|
115
|
+
class ProGen2Large(ProGen2):
|
|
116
|
+
slug = "progen2-large"
|
|
117
|
+
|
|
118
|
+
class ProGen2BFD90(ProGen2):
|
|
119
|
+
slug = "progen2-bfd90"
|
|
120
|
+
|
|
121
|
+
class AbLang(APIEndpoint):
|
|
122
|
+
action_classes = (PredictAction, EncodeAction, GenerateAction,)
|
|
123
|
+
predict_input_classes = (AAUnambiguous(),)
|
|
124
|
+
encode_input_classes = (AAUnambiguous(),)
|
|
125
|
+
generate_input_classes = (SingleOrMoreOccurrencesOf(token="*"), AAUnambiguousPlusExtra(extra=["*"]))
|
|
126
|
+
batch_size = 32
|
|
127
|
+
generate_input_key = "sequence"
|
|
128
|
+
|
|
129
|
+
class AbLangHeavy(AbLang):
|
|
130
|
+
slug = "ablang-heavy"
|
|
131
|
+
|
|
132
|
+
class AbLangLight(AbLang):
|
|
133
|
+
slug = "ablang-light"
|
|
134
|
+
|
|
135
|
+
class DNABERT(APIEndpoint):
|
|
136
|
+
slug = "dnabert"
|
|
137
|
+
action_classes = (EncodeAction,)
|
|
138
|
+
encode_input_classes = (DNAUnambiguous(),)
|
|
139
|
+
batch_size = 10
|
|
140
|
+
|
|
141
|
+
class DNABERT2(APIEndpoint):
|
|
142
|
+
slug = "dnabert2"
|
|
143
|
+
action_classes = (EncodeAction,)
|
|
144
|
+
encode_input_classes = (DNAUnambiguous(),)
|
|
145
|
+
batch_size = 10
|
|
146
|
+
|
|
147
|
+
class BioLMToxV1(APIEndpoint):
|
|
148
|
+
"""Example.
|
|
149
|
+
|
|
150
|
+
.. highlight:: python
|
|
151
|
+
.. code-block:: python
|
|
152
|
+
|
|
153
|
+
{
|
|
154
|
+
"instances": [{
|
|
155
|
+
"data": {"text": "MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQ"}
|
|
156
|
+
}]
|
|
157
|
+
}
|
|
158
|
+
"""
|
|
159
|
+
|
|
160
|
+
slug = "biolmtox_v1"
|
|
161
|
+
action_classes = (TransformAction, PredictAction,)
|
|
162
|
+
predict_input_classes = (AAUnambiguous(),)
|
|
163
|
+
transform_input_classes = (AAUnambiguous(),)
|
|
164
|
+
batch_size = 1
|
|
165
|
+
api_version = 1
|
|
166
|
+
|
|
167
|
+
class ProteInfer(APIEndpoint):
|
|
168
|
+
action_classes = (PredictAction,)
|
|
169
|
+
predict_input_classes = (AAExtended(),)
|
|
170
|
+
batch_size = 64
|
|
171
|
+
|
|
172
|
+
class ProteInferEC(ProteInfer):
|
|
173
|
+
slug = "proteinfer-ec"
|
|
174
|
+
|
|
175
|
+
class ProteInferGO(ProteInfer):
|
|
176
|
+
slug = "proteinfer-go"
|
|
@@ -26,4 +26,5 @@ if int(MULTIPROCESS_THREADS) > max_threads or int(MULTIPROCESS_THREADS) > 128:
|
|
|
26
26
|
elif int(MULTIPROCESS_THREADS) <= 0:
|
|
27
27
|
err = "Environment variable BIOLMAI_THREADS must be a positive integer."
|
|
28
28
|
raise ValueError(err)
|
|
29
|
-
|
|
29
|
+
BASE_API_URL_V1 = f"{BASE_DOMAIN}/api/v1"
|
|
30
|
+
BASE_API_URL = f"{BASE_DOMAIN}/api/v2"
|
|
@@ -7,11 +7,22 @@ def INST_DAT_TXT(batch, include_batch_size=False):
|
|
|
7
7
|
d["batch_size"] = len(d["instances"])
|
|
8
8
|
return d
|
|
9
9
|
|
|
10
|
+
def PARAMS_ITEMS(batch, key="sequence", params=None, include_batch_size=False):
|
|
11
|
+
d = {"items": []}
|
|
12
|
+
for _, row in batch.iterrows():
|
|
13
|
+
inst = {key: row.text}
|
|
14
|
+
d["items"].append(inst)
|
|
15
|
+
if include_batch_size is True:
|
|
16
|
+
d["batch_size"] = len(d["items"])
|
|
17
|
+
if isinstance(params, dict):
|
|
18
|
+
d["params"] = params
|
|
19
|
+
return d
|
|
20
|
+
|
|
10
21
|
|
|
11
22
|
def predict_resp_many_in_one_to_many_singles(
|
|
12
|
-
resp_json, status_code, batch_id, local_err, batch_size
|
|
23
|
+
resp_json, status_code, batch_id, local_err, batch_size, response_key = "results"
|
|
13
24
|
):
|
|
14
|
-
expected_root_key =
|
|
25
|
+
expected_root_key = response_key
|
|
15
26
|
to_ret = []
|
|
16
27
|
if not local_err and status_code and status_code == 200:
|
|
17
28
|
list_of_individual_seq_results = resp_json[expected_root_key]
|