biolmai 0.1.4__py2.py3-none-any.whl → 0.1.7__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of biolmai might be problematic. Click here for more details.

biolmai/__init__.py CHANGED
@@ -1,15 +1,7 @@
1
1
  """Top-level package for BioLM AI."""
2
2
  __author__ = """Nikhil Haas"""
3
- __email__ = 'nikhil@biolm.ai'
4
- __version__ = '0.1.4'
3
+ __email__ = "nikhil@biolm.ai"
4
+ __version__ = '0.1.7'
5
5
 
6
- from biolmai.biolmai import get_api_token, api_call
7
- from biolmai.api import ESMFoldSingleChain, ESMFoldMultiChain
8
6
 
9
-
10
- __all__ = [
11
- "get_api_token",
12
- "api_call",
13
- "ESMFoldSingleChain",
14
- "ESMFoldMultiChain",
15
- ]
7
+ __all__ = []
biolmai/api.py CHANGED
@@ -1,111 +1,29 @@
1
1
  """References to API endpoints."""
2
- from biolmai import biolmai
2
+ import datetime
3
3
  import inspect
4
- import pandas as pd
5
- import numpy as np
6
- from asyncio import create_task, gather, run, sleep
7
- from biolmai.asynch import async_main, async_api_calls
8
-
9
- from biolmai.biolmai import get_user_auth_header
10
- from biolmai.const import MULTIPROCESS_THREADS
4
+ import time
11
5
  from functools import lru_cache
12
6
 
13
- from biolmai.payloads import INST_DAT_TXT
14
- from biolmai.validate import ExtendedAAPlusExtra, SingleOccurrenceOf, \
15
- UnambiguousAA, \
16
- UnambiguousAAPlusExtra
17
-
18
-
19
- def predict_resp_many_in_one_to_many_singles(resp_json, status_code,
20
- batch_id, local_err, batch_size):
21
- expected_root_key = 'predictions'
22
- to_ret = []
23
- if not local_err and status_code and status_code == 200:
24
- list_of_individual_seq_results = resp_json[expected_root_key]
25
- elif local_err:
26
- list_of_individual_seq_results = [{'error': resp_json}]
27
- elif status_code and status_code != 200 and isinstance(resp_json, dict):
28
- list_of_individual_seq_results = [resp_json] * batch_size
29
- else:
30
- raise ValueError("Unexpected response in parser")
31
- for idx, item in enumerate(list_of_individual_seq_results):
32
- d = {'status_code': status_code,
33
- 'batch_id': batch_id,
34
- 'batch_item': idx}
35
- if not status_code or status_code != 200:
36
- d.update(item) # Put all resp keys at root there
37
- else:
38
- # We just append one item, mimicking a single seq in POST req/resp
39
- d[expected_root_key] = []
40
- d[expected_root_key].append(item)
41
- to_ret.append(d)
42
- return to_ret
43
-
44
-
45
- def async_api_call_wrapper(grouped_df, slug, action, payload_maker,
46
- response_key):
47
- """Wrap API calls to assist with sequence validation as a pre-cursor to
48
- each API call.
49
- """
50
- model_name = slug
51
- # payload = payload_maker(grouped_df)
52
- init_ploads = grouped_df.groupby('batch').apply(payload_maker, include_batch_size=True)
53
- ploads = init_ploads.to_list()
54
- init_ploads = init_ploads.to_frame(name='pload')
55
- init_ploads['batch'] = init_ploads.index
56
- init_ploads = init_ploads.reset_index(drop=True)
57
- assert len(ploads) == init_ploads.shape[0]
58
- for inst, b in zip(ploads, init_ploads['batch'].to_list()):
59
- inst['batch'] = b
60
-
61
- headers = get_user_auth_header() # Need to pull each time
62
- urls = [
63
- "https://github.com",
64
- "https://stackoverflow.com",
65
- "https://python.org",
66
- ]
67
- # concurrency = 3
68
- api_resp = run(async_api_calls(model_name, action, headers,
69
- ploads, response_key))
70
- api_resp = [item for sublist in api_resp for item in sublist]
71
- api_resp = sorted(api_resp, key=lambda x: x['batch_id'])
72
- # print(api_resp)
73
- # api_resp = biolmai.api_call(model_name, action, headers, payload,
74
- # response_key)
75
- # resp_json = api_resp.json()
76
- # batch_id = int(grouped_df.batch.iloc[0])
77
- # batch_size = grouped_df.shape[0]
78
- # response = predict_resp_many_in_one_to_many_singles(
79
- # resp_json, api_resp.status_code, batch_id, None, batch_size)
80
- return api_resp
81
-
82
-
83
- def api_call_wrapper(df, args):
84
- """Wrap API calls to assist with sequence validation as a pre-cursor to
85
- each API call.
86
- """
87
- model_name, action, payload_maker, response_key = args
88
- payload = payload_maker(df)
89
- headers = get_user_auth_header() # Need to pull each time
90
- api_resp = biolmai.api_call(model_name, action, headers, payload,
91
- response_key)
92
- resp_json = api_resp.json()
93
- batch_id = int(df.batch.iloc[0])
94
- batch_size = df.shape[0]
95
- response = predict_resp_many_in_one_to_many_singles(
96
- resp_json, api_resp.status_code, batch_id, None, batch_size)
97
- return response
7
+ import numpy as np
8
+ import pandas as pd
9
+ import requests
10
+ from requests.adapters import HTTPAdapter
11
+ from requests.packages.urllib3.util.retry import Retry
12
+
13
+ import biolmai
14
+ import biolmai.auth
15
+ from biolmai.asynch import async_api_call_wrapper
16
+ from biolmai.biolmai import log
17
+ from biolmai.const import MULTIPROCESS_THREADS
18
+ from biolmai.payloads import INST_DAT_TXT, predict_resp_many_in_one_to_many_singles
98
19
 
99
20
 
100
21
  @lru_cache(maxsize=64)
101
22
  def validate_endpoint_action(allowed_classes, method_name, api_class_name):
102
- action_method_name = method_name.split('.')[-1]
23
+ action_method_name = method_name.split(".")[-1]
103
24
  if action_method_name not in allowed_classes:
104
- err = 'Only {} supported on {}'
105
- err = err.format(
106
- list(allowed_classes),
107
- api_class_name
108
- )
25
+ err = "Only {} supported on {}"
26
+ err = err.format(list(allowed_classes), api_class_name)
109
27
  raise AssertionError(err)
110
28
 
111
29
 
@@ -125,24 +43,23 @@ def validate(f):
125
43
  # like ESMFoldSinglechain.
126
44
  class_obj_self = args[0]
127
45
  try:
128
- is_method = inspect.getfullargspec(f)[0][0] == 'self'
129
- except:
46
+ is_method = inspect.getfullargspec(f)[0][0] == "self"
47
+ except Exception:
130
48
  is_method = False
131
49
 
132
50
  # Is the function we decorated a class method?
133
51
  if is_method:
134
- name = '{}.{}.{}'.format(f.__module__, args[0].__class__.__name__,
135
- f.__name__)
52
+ name = f"{f.__module__}.{class_obj_self.__class__.__name__}.{f.__name__}"
136
53
  else:
137
- name = '{}.{}'.format(f.__module__, f.__name__)
54
+ name = f"{f.__module__}.{f.__name__}"
138
55
 
139
56
  if is_method:
140
57
  # Splits name, e.g. 'biolmai.api.ESMFoldSingleChain.predict'
141
- action_method_name = name.split('.')[-1]
58
+ action_method_name = name.split(".")[-1]
142
59
  validate_endpoint_action(
143
60
  class_obj_self.action_class_strings,
144
61
  action_method_name,
145
- class_obj_self.__class__.__name__
62
+ class_obj_self.__class__.__name__,
146
63
  )
147
64
 
148
65
  input_data = args[1]
@@ -150,35 +67,38 @@ def validate(f):
150
67
  for c in class_obj_self.seq_classes:
151
68
  # Validate input data against regex
152
69
  if class_obj_self.multiprocess_threads:
153
- validation = input_data.text.apply(text_validator, args=(c, ))
70
+ validation = input_data.text.apply(text_validator, args=(c,))
154
71
  else:
155
- validation = input_data.text.apply(text_validator, args=(c, ))
156
- if 'validation' not in input_data.columns:
157
- input_data['validation'] = validation
72
+ validation = input_data.text.apply(text_validator, args=(c,))
73
+ if "validation" not in input_data.columns:
74
+ input_data["validation"] = validation
158
75
  else:
159
- input_data['validation'] = input_data['validation'].str.cat(
160
- validation, sep='\n', na_rep='')
76
+ input_data["validation"] = input_data["validation"].str.cat(
77
+ validation, sep="\n", na_rep=""
78
+ )
161
79
 
162
80
  # Mark your batches, excluding invalid rows
163
81
  valid_dat = input_data.loc[input_data.validation.isnull(), :].copy()
164
82
  N = class_obj_self.batch_size # N rows will go per API request
165
83
  # JOIN back, which is by index
166
84
  if valid_dat.shape[0] != input_data.shape[0]:
167
- valid_dat['batch'] = np.arange(valid_dat.shape[0])//N
85
+ valid_dat["batch"] = np.arange(valid_dat.shape[0]) // N
168
86
  input_data = input_data.merge(
169
- valid_dat.batch, left_index=True, right_index=True, how='left')
87
+ valid_dat.batch, left_index=True, right_index=True, how="left"
88
+ )
170
89
  else:
171
- input_data['batch'] = np.arange(input_data.shape[0])//N
90
+ input_data["batch"] = np.arange(input_data.shape[0]) // N
172
91
 
173
92
  res = f(class_obj_self, input_data, **kwargs)
174
93
  return res
94
+
175
95
  return wrapper
176
96
 
177
97
 
178
98
  def convert_input(f):
179
99
  def wrapper(*args, **kwargs):
180
100
  # Get the user-input data argument to the decorated function
181
- class_obj_self = args[0]
101
+ # class_obj_self = args[0]
182
102
  input_data = args[1]
183
103
  # Make sure we have expected input types
184
104
  acceptable_inputs = (str, list, tuple, np.ndarray, pd.DataFrame)
@@ -196,12 +116,13 @@ def convert_input(f):
196
116
  if isinstance(input_data, pd.DataFrame) and len(input_data.shape) > 1:
197
117
  err = "Detected Pandas DataFrame - input a single vector or Series"
198
118
  raise AssertionError(err)
199
- input_data = pd.DataFrame(input_data, columns=['text'])
119
+ input_data = pd.DataFrame(input_data, columns=["text"])
200
120
  return f(args[0], input_data, **kwargs)
121
+
201
122
  return wrapper
202
123
 
203
124
 
204
- class APIEndpoint(object):
125
+ class APIEndpoint:
205
126
  batch_size = 3 # Overwrite in parent classes as needed
206
127
 
207
128
  def __init__(self, multiprocess_threads=None):
@@ -211,32 +132,26 @@ class APIEndpoint(object):
211
132
  else:
212
133
  self.multiprocess_threads = MULTIPROCESS_THREADS # Could be False
213
134
  # Get correct auth-like headers
214
- self.auth_headers = biolmai.get_user_auth_header()
215
- self.action_class_strings = tuple([
216
- c.__name__.replace('Action', '').lower() for c in self.action_classes
217
- ])
135
+ self.auth_headers = biolmai.auth.get_user_auth_header()
136
+ self.action_class_strings = tuple(
137
+ [c.__name__.replace("Action", "").lower() for c in self.action_classes]
138
+ )
218
139
 
219
- @convert_input
220
- @validate
221
- def predict(self, dat):
222
- keep_batches = dat.loc[~dat.batch.isnull(), ['text', 'batch']]
140
+ def post_batches(self, dat, slug, action, payload_maker, resp_key):
141
+ keep_batches = dat.loc[~dat.batch.isnull(), ["text", "batch"]]
223
142
  if keep_batches.shape[0] == 0:
224
143
  pass # Do nothing - we made nice JSON errors to return in the DF
225
144
  # err = "No inputs found following local validation"
226
145
  # raise AssertionError(err)
227
146
  if keep_batches.shape[0] > 0:
228
147
  api_resps = async_api_call_wrapper(
229
- keep_batches,
230
- self.slug,
231
- 'predict',
232
- INST_DAT_TXT,
233
- 'predictions'
148
+ keep_batches, slug, action, payload_maker, resp_key
234
149
  )
235
150
  if isinstance(api_resps, pd.DataFrame):
236
- batch_res = api_resps.explode('api_resp') # Should be lists of results
151
+ batch_res = api_resps.explode("api_resp") # Should be lists of results
237
152
  len_res = batch_res.shape[0]
238
153
  else:
239
- batch_res = pd.DataFrame({'api_resp': api_resps})
154
+ batch_res = pd.DataFrame({"api_resp": api_resps})
240
155
  len_res = batch_res.shape[0]
241
156
  orig_request_rows = keep_batches.shape[0]
242
157
  if len_res != orig_request_rows:
@@ -245,150 +160,151 @@ class APIEndpoint(object):
245
160
  raise AssertionError(err)
246
161
 
247
162
  # Stack the results horizontally w/ original rows of batches
248
- keep_batches['prev_idx'] = keep_batches.index
163
+ keep_batches["prev_idx"] = keep_batches.index
249
164
  keep_batches.reset_index(drop=False, inplace=True)
250
165
  batch_res.reset_index(drop=True, inplace=True)
251
- keep_batches['api_resp'] = batch_res
252
- keep_batches.set_index('prev_idx', inplace=True)
253
- dat = dat.join(keep_batches.reindex(['api_resp'], axis=1))
166
+ keep_batches["api_resp"] = batch_res
167
+ keep_batches.set_index("prev_idx", inplace=True)
168
+ dat = dat.join(keep_batches.reindex(["api_resp"], axis=1))
254
169
  else:
255
- dat['api_resp'] = None
170
+ dat["api_resp"] = None
171
+ return dat
172
+
173
+ def unpack_local_validations(self, dat):
174
+ dat.loc[dat.api_resp.isnull(), "api_resp"] = (
175
+ dat.loc[~dat.validation.isnull(), "validation"]
176
+ .apply(
177
+ predict_resp_many_in_one_to_many_singles, args=(None, None, True, None)
178
+ )
179
+ .explode()
180
+ )
256
181
 
257
- dat.loc[
258
- dat.api_resp.isnull(), 'api_resp'
259
- ] = dat.loc[~dat.validation.isnull(), 'validation'].apply(
260
- predict_resp_many_in_one_to_many_singles,
261
- args=(None, None, True, None)).explode()
182
+ return dat
262
183
 
184
+ @convert_input
185
+ @validate
186
+ def predict(self, dat):
187
+ dat = self.post_batches(dat, self.slug, "predict", INST_DAT_TXT, "predictions")
188
+ dat = self.unpack_local_validations(dat)
263
189
  return dat.api_resp.replace(np.nan, None).tolist()
264
190
 
265
191
  def infer(self, dat):
266
192
  return self.predict(dat)
267
193
 
194
+ @convert_input
268
195
  @validate
269
- def tokenize(self, dat):
270
- payload = {"instances": [{"data": {"text": dat}}]}
271
- resp = biolmai.api_call(
272
- model_name=self.slug,
273
- headers=self.auth_headers, # From APIEndpoint base class
274
- action='transform',
275
- payload=payload
196
+ def transform(self, dat):
197
+ dat = self.post_batches(
198
+ dat, self.slug, "transform", INST_DAT_TXT, "predictions"
276
199
  )
277
- return resp
200
+ dat = self.unpack_local_validations(dat)
201
+ return dat.api_resp.replace(np.nan, None).tolist()
278
202
 
203
+ # @convert_input
204
+ # @validate
205
+ # def encode(self, dat):
206
+ # # NOTE: we defined this for the specific case of ESM2
207
+ # # TODO: this will be need again in v2 of API contract
208
+ # dat = self.post_batches(dat, self.slug, "transform",
209
+ # INST_DAT_TXT, "embeddings")
210
+ # dat = self.unpack_local_validations(dat)
211
+ # return dat.api_resp.replace(np.nan, None).tolist()
279
212
 
280
- class PredictAction(object):
213
+ @convert_input
214
+ @validate
215
+ def generate(self, dat):
216
+ dat = self.post_batches(dat, self.slug, "generate", INST_DAT_TXT, "generated")
217
+ dat = self.unpack_local_validations(dat)
218
+ return dat.api_resp.replace(np.nan, None).tolist()
281
219
 
282
- def __str__(self):
283
- return 'PredictAction'
284
220
 
221
def retry_minutes(sess, URL, HEADERS, dat, timeout, mins):
    """POST ``dat`` to ``URL``, retrying failed attempts for up to ``mins`` minutes.

    At most five attempts are made, with a 5 second pause after each failed
    one. An attempt counts as failed when the request raises, when the status
    is an HTTP error other than 400/404, when the body cannot be decoded as
    JSON, or when the decoded JSON contains an ``"error"`` entry.

    Args:
        sess: Session-like object exposing
            ``post(url, headers=..., data=..., timeout=...)``.
        URL: Endpoint to POST to.
        HEADERS: Mapping of request headers. It is copied rather than
            modified; ``Content-Type: application/json`` is set on the copy.
        dat: Request body, forwarded as ``data=``.
        timeout: Per-request timeout forwarded to ``sess.post``.
        mins: Overall time budget in minutes.

    Returns:
        The response of the last attempt, or ``None`` when the last attempt
        produced no response or no attempt was made at all. Request failures
        are logged instead of being raised.
    """
    import json  # local import: only needed to render error payloads

    # Work on a copy so the caller's header mapping is left untouched.
    headers = dict(HEADERS)
    headers["Content-Type"] = "application/json"

    attempts, max_attempts = 0, 5
    # Bound up front so an empty time budget (loop never entered) yields None
    # instead of an unbound local.
    response = None
    try_until = datetime.datetime.now() + datetime.timedelta(minutes=mins)
    while datetime.datetime.now() < try_until and attempts < max_attempts:
        response = None  # a failed attempt must not leak the previous response
        try:
            log.info(f"Trying {datetime.datetime.now()}")
            response = sess.post(URL, headers=headers, data=dat, timeout=timeout)
            # 400/404 are not raised here; their JSON body is inspected below.
            if response.status_code not in (400, 404):
                response.raise_for_status()
            payload = response.json()
            if "error" in payload:
                raise ValueError(json.dumps(payload))
            break
        except Exception as e:
            log.warning(e)
            # Compare against None explicitly: a requests.Response is falsy
            # for 4xx/5xx statuses, which are exactly the bodies worth logging.
            if response is not None:
                log.warning(response.text)
            time.sleep(5)  # Wait 5 seconds between tries
            attempts += 1

    if response is None:
        log.warning("Got Nonetype response")
    elif "Server Error" in response.text:
        log.warning("Got Server Error")
    return response
285
254
 
286
- class GenerateAction(object):
287
255
 
256
def requests_retry_session(
    retries=3,
    backoff_factor=0.3,
    status_forcelist=None,
    session=None,
):
    """Return a ``requests`` session whose HTTP and HTTPS adapters retry.

    Args:
        retries: Retry budget, used for the total, read and connect limits
            alike.
        backoff_factor: Back-off factor handed to ``Retry``.
        status_forcelist: Status codes that force a retry. Defaults to every
            code from 400 through 599 inclusive.
        session: Existing session to configure; a new ``requests.Session`` is
            created when omitted.

    Returns:
        The configured session (the one passed in, when given).
    """
    if status_forcelist is None:
        # range() excludes its stop value, so 600 is required to cover 599.
        status_forcelist = list(range(400, 600))
    if session is None:
        session = requests.Session()
    # NOTE(review): urllib3's Retry only applies status-based retries to its
    # default (idempotent) method set, which does not include POST - confirm
    # whether POST callers expect the force-list to take effect.
    retry = Retry(
        total=retries,
        read=retries,
        connect=retries,
        backoff_factor=backoff_factor,
        status_forcelist=status_forcelist,
    )
    adapter = HTTPAdapter(max_retries=retry)
    # One adapter instance serves both schemes.
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    return session
276
+
277
+
278
+ class PredictAction:
288
279
  def __str__(self):
289
- return 'GenerateAction'
290
-
280
+ return "PredictAction"
291
281
 
292
- class TransformAction(object):
293
282
 
283
+ class GenerateAction:
294
284
  def __str__(self):
295
- return 'TransformAction'
285
+ return "GenerateAction"
296
286
 
297
287
 
298
- class ExplainAction(object):
299
-
288
+ class TransformAction:
300
289
  def __str__(self):
301
- return 'ExplainAction'
290
+ return "TransformAction"
291
+
302
292
 
293
+ # class EncodeAction:
294
+ # def __str__(self):
295
+ # return "EncodeAction"
303
296
 
304
- class SimilarityAction(object):
305
297
 
298
+ class ExplainAction:
306
299
  def __str__(self):
307
- return 'SimilarityAction'
300
+ return "ExplainAction"
308
301
 
309
302
 
310
- class FinetuneAction(object):
303
+ class SimilarityAction:
304
+ def __str__(self):
305
+ return "SimilarityAction"
306
+
311
307
 
308
+ class FinetuneAction:
312
309
  def __str__(self):
313
- return 'FinetuneAction'
314
-
315
-
316
- class ESMFoldSingleChain(APIEndpoint):
317
- slug = 'esmfold-singlechain'
318
- action_classes = (PredictAction, )
319
- seq_classes = (UnambiguousAA(), )
320
- batch_size = 2
321
-
322
-
323
- class ESMFoldMultiChain(APIEndpoint):
324
- slug = 'esmfold-multichain'
325
- action_classes = (PredictAction, )
326
- seq_classes = (ExtendedAAPlusExtra(extra=[':']), )
327
- batch_size = 2
328
-
329
-
330
- class ESM2Embeddings(APIEndpoint):
331
- """Example.
332
-
333
- ```python
334
- {
335
- "instances": [{
336
- "data": {"text": "MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQ"}
337
- }]
338
- }
339
- ```
340
- """
341
- slug = 'esm2_t33_650M_UR50D'
342
- action_classes = (TransformAction,)
343
- seq_classes = (UnambiguousAA(), )
344
- batch_size = 3
345
-
346
-
347
- class ESM1v1(APIEndpoint):
348
- """Example.
349
-
350
- ```python
351
- {
352
- "instances": [{
353
- "data": {"text": "QERLEUTGR<mask>SLGYNIVAT"}
354
- }]
355
- }
356
- ```
357
- """
358
- slug = 'esm1v_t33_650M_UR90S_1'
359
- action_classes = (PredictAction, )
360
- seq_classes = (SingleOccurrenceOf('<mask>'),
361
- ExtendedAAPlusExtra(extra=['<mask>']))
362
- batch_size = 5
363
-
364
-
365
- class ESM1v2(APIEndpoint):
366
- slug = 'esm1v_t33_650M_UR90S_2'
367
- action_classes = (PredictAction, )
368
- seq_classes = (SingleOccurrenceOf('<mask>'),
369
- ExtendedAAPlusExtra(extra=['<mask>']))
370
- batch_size = 5
371
-
372
-
373
- class ESM1v3(APIEndpoint):
374
- slug = 'esm1v_t33_650M_UR90S_3'
375
- action_classes = (PredictAction, )
376
- seq_classes = (SingleOccurrenceOf('<mask>'),
377
- ExtendedAAPlusExtra(extra=['<mask>']))
378
- batch_size = 5
379
-
380
-
381
- class ESM1v4(APIEndpoint):
382
- slug = 'esm1v_t33_650M_UR90S_4'
383
- action_classes = (PredictAction, )
384
- seq_classes = (SingleOccurrenceOf('<mask>'),
385
- ExtendedAAPlusExtra(extra=['<mask>']))
386
- batch_size = 5
387
-
388
-
389
- class ESM1v5(APIEndpoint):
390
- slug = 'esm1v_t33_650M_UR90S_5'
391
- action_classes = (PredictAction, )
392
- seq_classes = (SingleOccurrenceOf('<mask>'),
393
- ExtendedAAPlusExtra(extra=['<mask>']))
394
- batch_size = 5
310
+ return "FinetuneAction"