biolmai 0.1.4__py2.py3-none-any.whl → 0.1.7__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of biolmai might be problematic. Click here for more details.
- biolmai/__init__.py +3 -11
- biolmai/api.py +163 -247
- biolmai/asynch.py +90 -53
- biolmai/auth.py +75 -29
- biolmai/biolmai.py +1 -149
- biolmai/cli.py +30 -22
- biolmai/cls.py +96 -0
- biolmai/const.py +13 -11
- biolmai/payloads.py +28 -3
- biolmai/validate.py +55 -28
- {biolmai-0.1.4.dist-info → biolmai-0.1.7.dist-info}/METADATA +1 -1
- biolmai-0.1.7.dist-info/RECORD +18 -0
- {biolmai-0.1.4.dist-info → biolmai-0.1.7.dist-info}/WHEEL +1 -1
- biolmai-0.1.4.dist-info/RECORD +0 -18
- {biolmai-0.1.4.dist-info → biolmai-0.1.7.dist-info}/AUTHORS.rst +0 -0
- {biolmai-0.1.4.dist-info → biolmai-0.1.7.dist-info}/LICENSE +0 -0
- {biolmai-0.1.4.dist-info → biolmai-0.1.7.dist-info}/entry_points.txt +0 -0
- {biolmai-0.1.4.dist-info → biolmai-0.1.7.dist-info}/top_level.txt +0 -0
biolmai/__init__.py
CHANGED
|
@@ -1,15 +1,7 @@
|
|
|
1
1
|
"""Top-level package for BioLM AI."""
|
|
2
2
|
__author__ = """Nikhil Haas"""
|
|
3
|
-
__email__ =
|
|
4
|
-
__version__ = '0.1.
|
|
3
|
+
__email__ = "nikhil@biolm.ai"
|
|
4
|
+
__version__ = '0.1.7'
|
|
5
5
|
|
|
6
|
-
from biolmai.biolmai import get_api_token, api_call
|
|
7
|
-
from biolmai.api import ESMFoldSingleChain, ESMFoldMultiChain
|
|
8
6
|
|
|
9
|
-
|
|
10
|
-
__all__ = [
|
|
11
|
-
"get_api_token",
|
|
12
|
-
"api_call",
|
|
13
|
-
"ESMFoldSingleChain",
|
|
14
|
-
"ESMFoldMultiChain",
|
|
15
|
-
]
|
|
7
|
+
__all__ = []
|
biolmai/api.py
CHANGED
|
@@ -1,111 +1,29 @@
|
|
|
1
1
|
"""References to API endpoints."""
|
|
2
|
-
|
|
2
|
+
import datetime
|
|
3
3
|
import inspect
|
|
4
|
-
import
|
|
5
|
-
import numpy as np
|
|
6
|
-
from asyncio import create_task, gather, run, sleep
|
|
7
|
-
from biolmai.asynch import async_main, async_api_calls
|
|
8
|
-
|
|
9
|
-
from biolmai.biolmai import get_user_auth_header
|
|
10
|
-
from biolmai.const import MULTIPROCESS_THREADS
|
|
4
|
+
import time
|
|
11
5
|
from functools import lru_cache
|
|
12
6
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
elif local_err:
|
|
26
|
-
list_of_individual_seq_results = [{'error': resp_json}]
|
|
27
|
-
elif status_code and status_code != 200 and isinstance(resp_json, dict):
|
|
28
|
-
list_of_individual_seq_results = [resp_json] * batch_size
|
|
29
|
-
else:
|
|
30
|
-
raise ValueError("Unexpected response in parser")
|
|
31
|
-
for idx, item in enumerate(list_of_individual_seq_results):
|
|
32
|
-
d = {'status_code': status_code,
|
|
33
|
-
'batch_id': batch_id,
|
|
34
|
-
'batch_item': idx}
|
|
35
|
-
if not status_code or status_code != 200:
|
|
36
|
-
d.update(item) # Put all resp keys at root there
|
|
37
|
-
else:
|
|
38
|
-
# We just append one item, mimicking a single seq in POST req/resp
|
|
39
|
-
d[expected_root_key] = []
|
|
40
|
-
d[expected_root_key].append(item)
|
|
41
|
-
to_ret.append(d)
|
|
42
|
-
return to_ret
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
def async_api_call_wrapper(grouped_df, slug, action, payload_maker,
|
|
46
|
-
response_key):
|
|
47
|
-
"""Wrap API calls to assist with sequence validation as a pre-cursor to
|
|
48
|
-
each API call.
|
|
49
|
-
"""
|
|
50
|
-
model_name = slug
|
|
51
|
-
# payload = payload_maker(grouped_df)
|
|
52
|
-
init_ploads = grouped_df.groupby('batch').apply(payload_maker, include_batch_size=True)
|
|
53
|
-
ploads = init_ploads.to_list()
|
|
54
|
-
init_ploads = init_ploads.to_frame(name='pload')
|
|
55
|
-
init_ploads['batch'] = init_ploads.index
|
|
56
|
-
init_ploads = init_ploads.reset_index(drop=True)
|
|
57
|
-
assert len(ploads) == init_ploads.shape[0]
|
|
58
|
-
for inst, b in zip(ploads, init_ploads['batch'].to_list()):
|
|
59
|
-
inst['batch'] = b
|
|
60
|
-
|
|
61
|
-
headers = get_user_auth_header() # Need to pull each time
|
|
62
|
-
urls = [
|
|
63
|
-
"https://github.com",
|
|
64
|
-
"https://stackoverflow.com",
|
|
65
|
-
"https://python.org",
|
|
66
|
-
]
|
|
67
|
-
# concurrency = 3
|
|
68
|
-
api_resp = run(async_api_calls(model_name, action, headers,
|
|
69
|
-
ploads, response_key))
|
|
70
|
-
api_resp = [item for sublist in api_resp for item in sublist]
|
|
71
|
-
api_resp = sorted(api_resp, key=lambda x: x['batch_id'])
|
|
72
|
-
# print(api_resp)
|
|
73
|
-
# api_resp = biolmai.api_call(model_name, action, headers, payload,
|
|
74
|
-
# response_key)
|
|
75
|
-
# resp_json = api_resp.json()
|
|
76
|
-
# batch_id = int(grouped_df.batch.iloc[0])
|
|
77
|
-
# batch_size = grouped_df.shape[0]
|
|
78
|
-
# response = predict_resp_many_in_one_to_many_singles(
|
|
79
|
-
# resp_json, api_resp.status_code, batch_id, None, batch_size)
|
|
80
|
-
return api_resp
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
def api_call_wrapper(df, args):
|
|
84
|
-
"""Wrap API calls to assist with sequence validation as a pre-cursor to
|
|
85
|
-
each API call.
|
|
86
|
-
"""
|
|
87
|
-
model_name, action, payload_maker, response_key = args
|
|
88
|
-
payload = payload_maker(df)
|
|
89
|
-
headers = get_user_auth_header() # Need to pull each time
|
|
90
|
-
api_resp = biolmai.api_call(model_name, action, headers, payload,
|
|
91
|
-
response_key)
|
|
92
|
-
resp_json = api_resp.json()
|
|
93
|
-
batch_id = int(df.batch.iloc[0])
|
|
94
|
-
batch_size = df.shape[0]
|
|
95
|
-
response = predict_resp_many_in_one_to_many_singles(
|
|
96
|
-
resp_json, api_resp.status_code, batch_id, None, batch_size)
|
|
97
|
-
return response
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import requests
|
|
10
|
+
from requests.adapters import HTTPAdapter
|
|
11
|
+
from requests.packages.urllib3.util.retry import Retry
|
|
12
|
+
|
|
13
|
+
import biolmai
|
|
14
|
+
import biolmai.auth
|
|
15
|
+
from biolmai.asynch import async_api_call_wrapper
|
|
16
|
+
from biolmai.biolmai import log
|
|
17
|
+
from biolmai.const import MULTIPROCESS_THREADS
|
|
18
|
+
from biolmai.payloads import INST_DAT_TXT, predict_resp_many_in_one_to_many_singles
|
|
98
19
|
|
|
99
20
|
|
|
100
21
|
@lru_cache(maxsize=64)
|
|
101
22
|
def validate_endpoint_action(allowed_classes, method_name, api_class_name):
|
|
102
|
-
action_method_name = method_name.split(
|
|
23
|
+
action_method_name = method_name.split(".")[-1]
|
|
103
24
|
if action_method_name not in allowed_classes:
|
|
104
|
-
err =
|
|
105
|
-
err = err.format(
|
|
106
|
-
list(allowed_classes),
|
|
107
|
-
api_class_name
|
|
108
|
-
)
|
|
25
|
+
err = "Only {} supported on {}"
|
|
26
|
+
err = err.format(list(allowed_classes), api_class_name)
|
|
109
27
|
raise AssertionError(err)
|
|
110
28
|
|
|
111
29
|
|
|
@@ -125,24 +43,23 @@ def validate(f):
|
|
|
125
43
|
# like ESMFoldSinglechain.
|
|
126
44
|
class_obj_self = args[0]
|
|
127
45
|
try:
|
|
128
|
-
is_method = inspect.getfullargspec(f)[0][0] ==
|
|
129
|
-
except:
|
|
46
|
+
is_method = inspect.getfullargspec(f)[0][0] == "self"
|
|
47
|
+
except Exception:
|
|
130
48
|
is_method = False
|
|
131
49
|
|
|
132
50
|
# Is the function we decorated a class method?
|
|
133
51
|
if is_method:
|
|
134
|
-
name =
|
|
135
|
-
f.__name__)
|
|
52
|
+
name = f"{f.__module__}.{class_obj_self.__class__.__name__}.{f.__name__}"
|
|
136
53
|
else:
|
|
137
|
-
name =
|
|
54
|
+
name = f"{f.__module__}.{f.__name__}"
|
|
138
55
|
|
|
139
56
|
if is_method:
|
|
140
57
|
# Splits name, e.g. 'biolmai.api.ESMFoldSingleChain.predict'
|
|
141
|
-
action_method_name = name.split(
|
|
58
|
+
action_method_name = name.split(".")[-1]
|
|
142
59
|
validate_endpoint_action(
|
|
143
60
|
class_obj_self.action_class_strings,
|
|
144
61
|
action_method_name,
|
|
145
|
-
class_obj_self.__class__.__name__
|
|
62
|
+
class_obj_self.__class__.__name__,
|
|
146
63
|
)
|
|
147
64
|
|
|
148
65
|
input_data = args[1]
|
|
@@ -150,35 +67,38 @@ def validate(f):
|
|
|
150
67
|
for c in class_obj_self.seq_classes:
|
|
151
68
|
# Validate input data against regex
|
|
152
69
|
if class_obj_self.multiprocess_threads:
|
|
153
|
-
validation = input_data.text.apply(text_validator, args=(c,
|
|
70
|
+
validation = input_data.text.apply(text_validator, args=(c,))
|
|
154
71
|
else:
|
|
155
|
-
validation = input_data.text.apply(text_validator, args=(c,
|
|
156
|
-
if
|
|
157
|
-
input_data[
|
|
72
|
+
validation = input_data.text.apply(text_validator, args=(c,))
|
|
73
|
+
if "validation" not in input_data.columns:
|
|
74
|
+
input_data["validation"] = validation
|
|
158
75
|
else:
|
|
159
|
-
input_data[
|
|
160
|
-
validation, sep=
|
|
76
|
+
input_data["validation"] = input_data["validation"].str.cat(
|
|
77
|
+
validation, sep="\n", na_rep=""
|
|
78
|
+
)
|
|
161
79
|
|
|
162
80
|
# Mark your batches, excluding invalid rows
|
|
163
81
|
valid_dat = input_data.loc[input_data.validation.isnull(), :].copy()
|
|
164
82
|
N = class_obj_self.batch_size # N rows will go per API request
|
|
165
83
|
# JOIN back, which is by index
|
|
166
84
|
if valid_dat.shape[0] != input_data.shape[0]:
|
|
167
|
-
valid_dat[
|
|
85
|
+
valid_dat["batch"] = np.arange(valid_dat.shape[0]) // N
|
|
168
86
|
input_data = input_data.merge(
|
|
169
|
-
valid_dat.batch, left_index=True, right_index=True, how=
|
|
87
|
+
valid_dat.batch, left_index=True, right_index=True, how="left"
|
|
88
|
+
)
|
|
170
89
|
else:
|
|
171
|
-
input_data[
|
|
90
|
+
input_data["batch"] = np.arange(input_data.shape[0]) // N
|
|
172
91
|
|
|
173
92
|
res = f(class_obj_self, input_data, **kwargs)
|
|
174
93
|
return res
|
|
94
|
+
|
|
175
95
|
return wrapper
|
|
176
96
|
|
|
177
97
|
|
|
178
98
|
def convert_input(f):
|
|
179
99
|
def wrapper(*args, **kwargs):
|
|
180
100
|
# Get the user-input data argument to the decorated function
|
|
181
|
-
class_obj_self = args[0]
|
|
101
|
+
# class_obj_self = args[0]
|
|
182
102
|
input_data = args[1]
|
|
183
103
|
# Make sure we have expected input types
|
|
184
104
|
acceptable_inputs = (str, list, tuple, np.ndarray, pd.DataFrame)
|
|
@@ -196,12 +116,13 @@ def convert_input(f):
|
|
|
196
116
|
if isinstance(input_data, pd.DataFrame) and len(input_data.shape) > 1:
|
|
197
117
|
err = "Detected Pandas DataFrame - input a single vector or Series"
|
|
198
118
|
raise AssertionError(err)
|
|
199
|
-
input_data = pd.DataFrame(input_data, columns=[
|
|
119
|
+
input_data = pd.DataFrame(input_data, columns=["text"])
|
|
200
120
|
return f(args[0], input_data, **kwargs)
|
|
121
|
+
|
|
201
122
|
return wrapper
|
|
202
123
|
|
|
203
124
|
|
|
204
|
-
class APIEndpoint
|
|
125
|
+
class APIEndpoint:
|
|
205
126
|
batch_size = 3 # Overwrite in parent classes as needed
|
|
206
127
|
|
|
207
128
|
def __init__(self, multiprocess_threads=None):
|
|
@@ -211,32 +132,26 @@ class APIEndpoint(object):
|
|
|
211
132
|
else:
|
|
212
133
|
self.multiprocess_threads = MULTIPROCESS_THREADS # Could be False
|
|
213
134
|
# Get correct auth-like headers
|
|
214
|
-
self.auth_headers = biolmai.get_user_auth_header()
|
|
215
|
-
self.action_class_strings = tuple(
|
|
216
|
-
c.__name__.replace(
|
|
217
|
-
|
|
135
|
+
self.auth_headers = biolmai.auth.get_user_auth_header()
|
|
136
|
+
self.action_class_strings = tuple(
|
|
137
|
+
[c.__name__.replace("Action", "").lower() for c in self.action_classes]
|
|
138
|
+
)
|
|
218
139
|
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
def predict(self, dat):
|
|
222
|
-
keep_batches = dat.loc[~dat.batch.isnull(), ['text', 'batch']]
|
|
140
|
+
def post_batches(self, dat, slug, action, payload_maker, resp_key):
|
|
141
|
+
keep_batches = dat.loc[~dat.batch.isnull(), ["text", "batch"]]
|
|
223
142
|
if keep_batches.shape[0] == 0:
|
|
224
143
|
pass # Do nothing - we made nice JSON errors to return in the DF
|
|
225
144
|
# err = "No inputs found following local validation"
|
|
226
145
|
# raise AssertionError(err)
|
|
227
146
|
if keep_batches.shape[0] > 0:
|
|
228
147
|
api_resps = async_api_call_wrapper(
|
|
229
|
-
keep_batches,
|
|
230
|
-
self.slug,
|
|
231
|
-
'predict',
|
|
232
|
-
INST_DAT_TXT,
|
|
233
|
-
'predictions'
|
|
148
|
+
keep_batches, slug, action, payload_maker, resp_key
|
|
234
149
|
)
|
|
235
150
|
if isinstance(api_resps, pd.DataFrame):
|
|
236
|
-
batch_res = api_resps.explode(
|
|
151
|
+
batch_res = api_resps.explode("api_resp") # Should be lists of results
|
|
237
152
|
len_res = batch_res.shape[0]
|
|
238
153
|
else:
|
|
239
|
-
batch_res = pd.DataFrame({
|
|
154
|
+
batch_res = pd.DataFrame({"api_resp": api_resps})
|
|
240
155
|
len_res = batch_res.shape[0]
|
|
241
156
|
orig_request_rows = keep_batches.shape[0]
|
|
242
157
|
if len_res != orig_request_rows:
|
|
@@ -245,150 +160,151 @@ class APIEndpoint(object):
|
|
|
245
160
|
raise AssertionError(err)
|
|
246
161
|
|
|
247
162
|
# Stack the results horizontally w/ original rows of batches
|
|
248
|
-
keep_batches[
|
|
163
|
+
keep_batches["prev_idx"] = keep_batches.index
|
|
249
164
|
keep_batches.reset_index(drop=False, inplace=True)
|
|
250
165
|
batch_res.reset_index(drop=True, inplace=True)
|
|
251
|
-
keep_batches[
|
|
252
|
-
keep_batches.set_index(
|
|
253
|
-
dat = dat.join(keep_batches.reindex([
|
|
166
|
+
keep_batches["api_resp"] = batch_res
|
|
167
|
+
keep_batches.set_index("prev_idx", inplace=True)
|
|
168
|
+
dat = dat.join(keep_batches.reindex(["api_resp"], axis=1))
|
|
254
169
|
else:
|
|
255
|
-
dat[
|
|
170
|
+
dat["api_resp"] = None
|
|
171
|
+
return dat
|
|
172
|
+
|
|
173
|
+
def unpack_local_validations(self, dat):
|
|
174
|
+
dat.loc[dat.api_resp.isnull(), "api_resp"] = (
|
|
175
|
+
dat.loc[~dat.validation.isnull(), "validation"]
|
|
176
|
+
.apply(
|
|
177
|
+
predict_resp_many_in_one_to_many_singles, args=(None, None, True, None)
|
|
178
|
+
)
|
|
179
|
+
.explode()
|
|
180
|
+
)
|
|
256
181
|
|
|
257
|
-
dat
|
|
258
|
-
dat.api_resp.isnull(), 'api_resp'
|
|
259
|
-
] = dat.loc[~dat.validation.isnull(), 'validation'].apply(
|
|
260
|
-
predict_resp_many_in_one_to_many_singles,
|
|
261
|
-
args=(None, None, True, None)).explode()
|
|
182
|
+
return dat
|
|
262
183
|
|
|
184
|
+
@convert_input
|
|
185
|
+
@validate
|
|
186
|
+
def predict(self, dat):
|
|
187
|
+
dat = self.post_batches(dat, self.slug, "predict", INST_DAT_TXT, "predictions")
|
|
188
|
+
dat = self.unpack_local_validations(dat)
|
|
263
189
|
return dat.api_resp.replace(np.nan, None).tolist()
|
|
264
190
|
|
|
265
191
|
def infer(self, dat):
|
|
266
192
|
return self.predict(dat)
|
|
267
193
|
|
|
194
|
+
@convert_input
|
|
268
195
|
@validate
|
|
269
|
-
def
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
model_name=self.slug,
|
|
273
|
-
headers=self.auth_headers, # From APIEndpoint base class
|
|
274
|
-
action='transform',
|
|
275
|
-
payload=payload
|
|
196
|
+
def transform(self, dat):
|
|
197
|
+
dat = self.post_batches(
|
|
198
|
+
dat, self.slug, "transform", INST_DAT_TXT, "predictions"
|
|
276
199
|
)
|
|
277
|
-
|
|
200
|
+
dat = self.unpack_local_validations(dat)
|
|
201
|
+
return dat.api_resp.replace(np.nan, None).tolist()
|
|
278
202
|
|
|
203
|
+
# @convert_input
|
|
204
|
+
# @validate
|
|
205
|
+
# def encode(self, dat):
|
|
206
|
+
# # NOTE: we defined this for the specific case of ESM2
|
|
207
|
+
# # TODO: this will be needed again in v2 of API contract
|
|
208
|
+
# dat = self.post_batches(dat, self.slug, "transform",
|
|
209
|
+
# INST_DAT_TXT, "embeddings")
|
|
210
|
+
# dat = self.unpack_local_validations(dat)
|
|
211
|
+
# return dat.api_resp.replace(np.nan, None).tolist()
|
|
279
212
|
|
|
280
|
-
|
|
213
|
+
@convert_input
|
|
214
|
+
@validate
|
|
215
|
+
def generate(self, dat):
|
|
216
|
+
dat = self.post_batches(dat, self.slug, "generate", INST_DAT_TXT, "generated")
|
|
217
|
+
dat = self.unpack_local_validations(dat)
|
|
218
|
+
return dat.api_resp.replace(np.nan, None).tolist()
|
|
281
219
|
|
|
282
|
-
def __str__(self):
|
|
283
|
-
return 'PredictAction'
|
|
284
220
|
|
|
221
|
+
def retry_minutes(sess, URL, HEADERS, dat, timeout, mins):
|
|
222
|
+
"""Retry for N minutes."""
|
|
223
|
+
HEADERS.update({"Content-Type": "application/json"})
|
|
224
|
+
attempts, max_attempts = 0, 5
|
|
225
|
+
try:
|
|
226
|
+
now = datetime.datetime.now()
|
|
227
|
+
try_until = now + datetime.timedelta(minutes=mins)
|
|
228
|
+
while datetime.datetime.now() < try_until and attempts < max_attempts:
|
|
229
|
+
response = None
|
|
230
|
+
try:
|
|
231
|
+
log.info(f"Trying {datetime.datetime.now()}")
|
|
232
|
+
response = sess.post(URL, headers=HEADERS, data=dat, timeout=timeout)
|
|
233
|
+
if response.status_code not in (400, 404):
|
|
234
|
+
response.raise_for_status()
|
|
235
|
+
if "error" in response.json():
|
|
236
|
+
raise ValueError(response.json().dumps())
|
|
237
|
+
else:
|
|
238
|
+
break
|
|
239
|
+
except Exception as e:
|
|
240
|
+
log.warning(e)
|
|
241
|
+
if response:
|
|
242
|
+
log.warning(response.text)
|
|
243
|
+
time.sleep(5) # Wait 5 seconds between tries
|
|
244
|
+
attempts += 1
|
|
245
|
+
if response is None:
|
|
246
|
+
err = "Got Nonetype response"
|
|
247
|
+
raise ValueError(err)
|
|
248
|
+
elif "Server Error" in response.text:
|
|
249
|
+
err = "Got Server Error"
|
|
250
|
+
raise ValueError(err)
|
|
251
|
+
except Exception:
|
|
252
|
+
return response
|
|
253
|
+
return response
|
|
285
254
|
|
|
286
|
-
class GenerateAction(object):
|
|
287
255
|
|
|
256
|
+
def requests_retry_session(
|
|
257
|
+
retries=3,
|
|
258
|
+
backoff_factor=0.3,
|
|
259
|
+
status_forcelist=None,
|
|
260
|
+
session=None,
|
|
261
|
+
):
|
|
262
|
+
if status_forcelist is None:
|
|
263
|
+
status_forcelist = list(range(400, 599))
|
|
264
|
+
session = session or requests.Session()
|
|
265
|
+
retry = Retry(
|
|
266
|
+
total=retries,
|
|
267
|
+
read=retries,
|
|
268
|
+
connect=retries,
|
|
269
|
+
backoff_factor=backoff_factor,
|
|
270
|
+
status_forcelist=status_forcelist,
|
|
271
|
+
)
|
|
272
|
+
adapter = HTTPAdapter(max_retries=retry)
|
|
273
|
+
session.mount("http://", adapter)
|
|
274
|
+
session.mount("https://", adapter)
|
|
275
|
+
return session
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class PredictAction:
|
|
288
279
|
def __str__(self):
|
|
289
|
-
return
|
|
290
|
-
|
|
280
|
+
return "PredictAction"
|
|
291
281
|
|
|
292
|
-
class TransformAction(object):
|
|
293
282
|
|
|
283
|
+
class GenerateAction:
|
|
294
284
|
def __str__(self):
|
|
295
|
-
return
|
|
285
|
+
return "GenerateAction"
|
|
296
286
|
|
|
297
287
|
|
|
298
|
-
class
|
|
299
|
-
|
|
288
|
+
class TransformAction:
|
|
300
289
|
def __str__(self):
|
|
301
|
-
return
|
|
290
|
+
return "TransformAction"
|
|
291
|
+
|
|
302
292
|
|
|
293
|
+
# class EncodeAction:
|
|
294
|
+
# def __str__(self):
|
|
295
|
+
# return "EncodeAction"
|
|
303
296
|
|
|
304
|
-
class SimilarityAction(object):
|
|
305
297
|
|
|
298
|
+
class ExplainAction:
|
|
306
299
|
def __str__(self):
|
|
307
|
-
return
|
|
300
|
+
return "ExplainAction"
|
|
308
301
|
|
|
309
302
|
|
|
310
|
-
class
|
|
303
|
+
class SimilarityAction:
|
|
304
|
+
def __str__(self):
|
|
305
|
+
return "SimilarityAction"
|
|
306
|
+
|
|
311
307
|
|
|
308
|
+
class FinetuneAction:
|
|
312
309
|
def __str__(self):
|
|
313
|
-
return
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
class ESMFoldSingleChain(APIEndpoint):
|
|
317
|
-
slug = 'esmfold-singlechain'
|
|
318
|
-
action_classes = (PredictAction, )
|
|
319
|
-
seq_classes = (UnambiguousAA(), )
|
|
320
|
-
batch_size = 2
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
class ESMFoldMultiChain(APIEndpoint):
|
|
324
|
-
slug = 'esmfold-multichain'
|
|
325
|
-
action_classes = (PredictAction, )
|
|
326
|
-
seq_classes = (ExtendedAAPlusExtra(extra=[':']), )
|
|
327
|
-
batch_size = 2
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
class ESM2Embeddings(APIEndpoint):
|
|
331
|
-
"""Example.
|
|
332
|
-
|
|
333
|
-
```python
|
|
334
|
-
{
|
|
335
|
-
"instances": [{
|
|
336
|
-
"data": {"text": "MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQ"}
|
|
337
|
-
}]
|
|
338
|
-
}
|
|
339
|
-
```
|
|
340
|
-
"""
|
|
341
|
-
slug = 'esm2_t33_650M_UR50D'
|
|
342
|
-
action_classes = (TransformAction,)
|
|
343
|
-
seq_classes = (UnambiguousAA(), )
|
|
344
|
-
batch_size = 3
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
class ESM1v1(APIEndpoint):
|
|
348
|
-
"""Example.
|
|
349
|
-
|
|
350
|
-
```python
|
|
351
|
-
{
|
|
352
|
-
"instances": [{
|
|
353
|
-
"data": {"text": "QERLEUTGR<mask>SLGYNIVAT"}
|
|
354
|
-
}]
|
|
355
|
-
}
|
|
356
|
-
```
|
|
357
|
-
"""
|
|
358
|
-
slug = 'esm1v_t33_650M_UR90S_1'
|
|
359
|
-
action_classes = (PredictAction, )
|
|
360
|
-
seq_classes = (SingleOccurrenceOf('<mask>'),
|
|
361
|
-
ExtendedAAPlusExtra(extra=['<mask>']))
|
|
362
|
-
batch_size = 5
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
class ESM1v2(APIEndpoint):
|
|
366
|
-
slug = 'esm1v_t33_650M_UR90S_2'
|
|
367
|
-
action_classes = (PredictAction, )
|
|
368
|
-
seq_classes = (SingleOccurrenceOf('<mask>'),
|
|
369
|
-
ExtendedAAPlusExtra(extra=['<mask>']))
|
|
370
|
-
batch_size = 5
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
class ESM1v3(APIEndpoint):
|
|
374
|
-
slug = 'esm1v_t33_650M_UR90S_3'
|
|
375
|
-
action_classes = (PredictAction, )
|
|
376
|
-
seq_classes = (SingleOccurrenceOf('<mask>'),
|
|
377
|
-
ExtendedAAPlusExtra(extra=['<mask>']))
|
|
378
|
-
batch_size = 5
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
class ESM1v4(APIEndpoint):
|
|
382
|
-
slug = 'esm1v_t33_650M_UR90S_4'
|
|
383
|
-
action_classes = (PredictAction, )
|
|
384
|
-
seq_classes = (SingleOccurrenceOf('<mask>'),
|
|
385
|
-
ExtendedAAPlusExtra(extra=['<mask>']))
|
|
386
|
-
batch_size = 5
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
class ESM1v5(APIEndpoint):
|
|
390
|
-
slug = 'esm1v_t33_650M_UR90S_5'
|
|
391
|
-
action_classes = (PredictAction, )
|
|
392
|
-
seq_classes = (SingleOccurrenceOf('<mask>'),
|
|
393
|
-
ExtendedAAPlusExtra(extra=['<mask>']))
|
|
394
|
-
batch_size = 5
|
|
310
|
+
return "FinetuneAction"
|