biolmai 0.1.4__tar.gz → 0.1.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of biolmai might be problematic.
Files changed (85)
  1. {biolmai-0.1.4 → biolmai-0.1.7}/PKG-INFO +1 -1
  2. biolmai-0.1.7/biolmai/__init__.py +7 -0
  3. biolmai-0.1.7/biolmai/api.py +310 -0
  4. {biolmai-0.1.4 → biolmai-0.1.7}/biolmai/asynch.py +90 -53
  5. {biolmai-0.1.4 → biolmai-0.1.7}/biolmai/auth.py +75 -29
  6. biolmai-0.1.7/biolmai/biolmai.py +5 -0
  7. biolmai-0.1.7/biolmai/cli.py +75 -0
  8. biolmai-0.1.7/biolmai/cls.py +97 -0
  9. {biolmai-0.1.4 → biolmai-0.1.7}/biolmai/const.py +13 -11
  10. biolmai-0.1.7/biolmai/payloads.py +33 -0
  11. {biolmai-0.1.4 → biolmai-0.1.7}/biolmai/validate.py +55 -28
  12. {biolmai-0.1.4 → biolmai-0.1.7}/biolmai.egg-info/PKG-INFO +1 -1
  13. biolmai-0.1.7/biolmai.egg-info/SOURCES.txt +64 -0
  14. biolmai-0.1.7/docs/_static/api_reference_icon.png +0 -0
  15. biolmai-0.1.7/docs/_static/chat_agents_icon.png +0 -0
  16. biolmai-0.1.7/docs/_static/jupyter_notebooks_icon.png +0 -0
  17. biolmai-0.1.7/docs/_static/model_docs_icon.png +0 -0
  18. biolmai-0.1.7/docs/_static/python_sdk_icon.png +0 -0
  19. biolmai-0.1.7/docs/_static/tutorials_icon.png +0 -0
  20. {biolmai-0.1.4 → biolmai-0.1.7}/docs/conf.py +32 -44
  21. biolmai-0.1.7/docs/index.rst +107 -0
  22. biolmai-0.1.7/docs/model-docs/DNABERT.rst +640 -0
  23. biolmai-0.1.7/docs/model-docs/ESM-1v.rst +362 -0
  24. biolmai-0.1.7/docs/model-docs/ESM2_Embeddings.rst +242 -0
  25. biolmai-0.1.4/docs/model-docs/esm2_fold.rst → biolmai-0.1.7/docs/model-docs/ESMFold.rst +62 -63
  26. biolmai-0.1.7/docs/model-docs/ESM_InverseFold.rst +278 -0
  27. biolmai-0.1.7/docs/model-docs/ProtGPT2.rst +609 -0
  28. biolmai-0.1.7/docs/model-docs/ProteInfer_EC.rst +249 -0
  29. biolmai-0.1.7/docs/model-docs/ProteInfer_GO.rst +329 -0
  30. biolmai-0.1.7/docs/model-docs/index.rst +13 -0
  31. biolmai-0.1.7/docs/model-docs/progen2/ProGen2_BFD90.rst +251 -0
  32. biolmai-0.1.7/docs/model-docs/progen2/ProGen2_Medium.rst +248 -0
  33. biolmai-0.1.7/docs/model-docs/progen2/ProGen2_OAS.rst +246 -0
  34. biolmai-0.1.7/docs/model-docs/progen2/index.rst +10 -0
  35. biolmai-0.1.7/docs/python-client/get_started/authorization.rst +9 -0
  36. {biolmai-0.1.4/docs/python-client → biolmai-0.1.7/docs/python-client/get_started}/quickstart.rst +7 -0
  37. biolmai-0.1.7/docs/python-client/index.rst +18 -0
  38. biolmai-0.1.7/docs/python-client/usage.rst +7 -0
  39. biolmai-0.1.7/docs/tutorials_use_cases/notebooks.rst +9 -0
  40. biolmai-0.1.7/pyproject.toml +44 -0
  41. {biolmai-0.1.4 → biolmai-0.1.7}/setup.cfg +7 -2
  42. biolmai-0.1.7/setup.py +53 -0
  43. biolmai-0.1.7/tests/test_biolmai.py +263 -0
  44. biolmai-0.1.4/biolmai/__init__.py +0 -15
  45. biolmai-0.1.4/biolmai/api.py +0 -394
  46. biolmai-0.1.4/biolmai/biolmai.py +0 -153
  47. biolmai-0.1.4/biolmai/cli.py +0 -67
  48. biolmai-0.1.4/biolmai/cls.py +0 -1
  49. biolmai-0.1.4/biolmai/payloads.py +0 -8
  50. biolmai-0.1.4/biolmai.egg-info/SOURCES.txt +0 -51
  51. biolmai-0.1.4/docs/index.rst +0 -74
  52. biolmai-0.1.4/docs/model-docs/admonitions.rst +0 -39
  53. biolmai-0.1.4/docs/model-docs/esm2_embeddings.rst +0 -10
  54. biolmai-0.1.4/docs/python-client/authors.rst +0 -1
  55. biolmai-0.1.4/docs/python-client/contributing.rst +0 -1
  56. biolmai-0.1.4/docs/python-client/history.rst +0 -1
  57. biolmai-0.1.4/docs/python-client/readme.rst +0 -1
  58. biolmai-0.1.4/docs/python-client/usage.rst +0 -8
  59. biolmai-0.1.4/docs/tutorials_use_cases/bulk_protein_folding.rst +0 -3
  60. biolmai-0.1.4/docs/tutorials_use_cases/dna_tutorials.rst +0 -8
  61. biolmai-0.1.4/docs/tutorials_use_cases/protein_tutorials.rst +0 -15
  62. biolmai-0.1.4/setup.py +0 -51
  63. biolmai-0.1.4/tests/test_biolmai.py +0 -226
  64. {biolmai-0.1.4 → biolmai-0.1.7}/AUTHORS.rst +0 -0
  65. {biolmai-0.1.4 → biolmai-0.1.7}/CONTRIBUTING.rst +0 -0
  66. {biolmai-0.1.4 → biolmai-0.1.7}/HISTORY.rst +0 -0
  67. {biolmai-0.1.4 → biolmai-0.1.7}/LICENSE +0 -0
  68. {biolmai-0.1.4 → biolmai-0.1.7}/MANIFEST.in +0 -0
  69. {biolmai-0.1.4 → biolmai-0.1.7}/README.rst +0 -0
  70. {biolmai-0.1.4 → biolmai-0.1.7}/biolmai/ltc.py +0 -0
  71. {biolmai-0.1.4 → biolmai-0.1.7}/biolmai.egg-info/dependency_links.txt +0 -0
  72. {biolmai-0.1.4 → biolmai-0.1.7}/biolmai.egg-info/entry_points.txt +0 -0
  73. {biolmai-0.1.4 → biolmai-0.1.7}/biolmai.egg-info/not-zip-safe +0 -0
  74. {biolmai-0.1.4 → biolmai-0.1.7}/biolmai.egg-info/requires.txt +0 -0
  75. {biolmai-0.1.4 → biolmai-0.1.7}/biolmai.egg-info/top_level.txt +0 -0
  76. {biolmai-0.1.4 → biolmai-0.1.7}/docs/Makefile +0 -0
  77. {biolmai-0.1.4 → biolmai-0.1.7}/docs/_static/biolm_docs_logo_dark.png +0 -0
  78. {biolmai-0.1.4 → biolmai-0.1.7}/docs/_static/biolm_docs_logo_light.png +0 -0
  79. {biolmai-0.1.4 → biolmai-0.1.7}/docs/biolmai.rst +0 -0
  80. {biolmai-0.1.4 → biolmai-0.1.7}/docs/make.bat +0 -0
  81. {biolmai-0.1.4 → biolmai-0.1.7}/docs/model-docs/img/book_icon.png +0 -0
  82. {biolmai-0.1.4 → biolmai-0.1.7}/docs/model-docs/img/esmfold_perf.png +0 -0
  83. {biolmai-0.1.4 → biolmai-0.1.7}/docs/modules.rst +0 -0
  84. {biolmai-0.1.4/docs/python-client → biolmai-0.1.7/docs/python-client/get_started}/installation.rst +0 -0
  85. {biolmai-0.1.4 → biolmai-0.1.7}/tests/__init__.py +0 -0
--- biolmai-0.1.4/PKG-INFO
+++ biolmai-0.1.7/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: biolmai
-Version: 0.1.4
+Version: 0.1.7
 Summary: Python client and SDK for https://biolm.ai
 Home-page: https://github.com/BioLM/py-biolm
 Author: Nikhil Haas
--- /dev/null
+++ biolmai-0.1.7/biolmai/__init__.py
@@ -0,0 +1,7 @@
+"""Top-level package for BioLM AI."""
+__author__ = """Nikhil Haas"""
+__email__ = "nikhil@biolm.ai"
+__version__ = '0.1.7'
+
+
+__all__ = []
--- /dev/null
+++ biolmai-0.1.7/biolmai/api.py
@@ -0,0 +1,310 @@
+"""References to API endpoints."""
+import datetime
+import inspect
+import time
+from functools import lru_cache
+
+import numpy as np
+import pandas as pd
+import requests
+from requests.adapters import HTTPAdapter
+from requests.packages.urllib3.util.retry import Retry
+
+import biolmai
+import biolmai.auth
+from biolmai.asynch import async_api_call_wrapper
+from biolmai.biolmai import log
+from biolmai.const import MULTIPROCESS_THREADS
+from biolmai.payloads import INST_DAT_TXT, predict_resp_many_in_one_to_many_singles
+
+
+@lru_cache(maxsize=64)
+def validate_endpoint_action(allowed_classes, method_name, api_class_name):
+    action_method_name = method_name.split(".")[-1]
+    if action_method_name not in allowed_classes:
+        err = "Only {} supported on {}"
+        err = err.format(list(allowed_classes), api_class_name)
+        raise AssertionError(err)
+
+
+def text_validator(text, c):
+    """Validate some text against a class-based validator, returning a string
+    if invalid, or None otherwise."""
+    try:
+        c(text)
+    except Exception as e:
+        return str(e)
+
+
+def validate(f):
+    def wrapper(*args, **kwargs):
+        # Get class instance at runtime, so you can access not just
+        # APIEndpoints, but any *parent* classes of that,
+        # like ESMFoldSinglechain.
+        class_obj_self = args[0]
+        try:
+            is_method = inspect.getfullargspec(f)[0][0] == "self"
+        except Exception:
+            is_method = False
+
+        # Is the function we decorated a class method?
+        if is_method:
+            name = f"{f.__module__}.{class_obj_self.__class__.__name__}.{f.__name__}"
+        else:
+            name = f"{f.__module__}.{f.__name__}"
+
+        if is_method:
+            # Splits name, e.g. 'biolmai.api.ESMFoldSingleChain.predict'
+            action_method_name = name.split(".")[-1]
+            validate_endpoint_action(
+                class_obj_self.action_class_strings,
+                action_method_name,
+                class_obj_self.__class__.__name__,
+            )
+
+        input_data = args[1]
+        # Validate each row's text/input based on class attribute `seq_classes`
+        for c in class_obj_self.seq_classes:
+            # Validate input data against regex
+            if class_obj_self.multiprocess_threads:
+                validation = input_data.text.apply(text_validator, args=(c,))
+            else:
+                validation = input_data.text.apply(text_validator, args=(c,))
+            if "validation" not in input_data.columns:
+                input_data["validation"] = validation
+            else:
+                input_data["validation"] = input_data["validation"].str.cat(
+                    validation, sep="\n", na_rep=""
+                )
+
+        # Mark your batches, excluding invalid rows
+        valid_dat = input_data.loc[input_data.validation.isnull(), :].copy()
+        N = class_obj_self.batch_size  # N rows will go per API request
+        # JOIN back, which is by index
+        if valid_dat.shape[0] != input_data.shape[0]:
+            valid_dat["batch"] = np.arange(valid_dat.shape[0]) // N
+            input_data = input_data.merge(
+                valid_dat.batch, left_index=True, right_index=True, how="left"
+            )
+        else:
+            input_data["batch"] = np.arange(input_data.shape[0]) // N
+
+        res = f(class_obj_self, input_data, **kwargs)
+        return res
+
+    return wrapper
+
+
+def convert_input(f):
+    def wrapper(*args, **kwargs):
+        # Get the user-input data argument to the decorated function
+        # class_obj_self = args[0]
+        input_data = args[1]
+        # Make sure we have expected input types
+        acceptable_inputs = (str, list, tuple, np.ndarray, pd.DataFrame)
+        if not isinstance(input_data, acceptable_inputs):
+            err = "Input must be one or many DNA or protein strings"
+            raise ValueError(err)
+        # Convert single-sequence input to list
+        if isinstance(input_data, str):
+            input_data = [input_data]
+        # Make sure we don't have a matrix
+        if isinstance(input_data, np.ndarray) and len(input_data.shape) > 1:
+            err = "Detected Numpy matrix - input a single vector or array"
+            raise AssertionError(err)
+        # Make sure we don't have a >=2D DF
+        if isinstance(input_data, pd.DataFrame) and len(input_data.shape) > 1:
+            err = "Detected Pandas DataFrame - input a single vector or Series"
+            raise AssertionError(err)
+        input_data = pd.DataFrame(input_data, columns=["text"])
+        return f(args[0], input_data, **kwargs)
+
+    return wrapper
+
+
+class APIEndpoint:
+    batch_size = 3  # Overwrite in parent classes as needed
+
+    def __init__(self, multiprocess_threads=None):
+        # Check for instance-specific threads, otherwise read from env var
+        if multiprocess_threads is not None:
+            self.multiprocess_threads = multiprocess_threads
+        else:
+            self.multiprocess_threads = MULTIPROCESS_THREADS  # Could be False
+        # Get correct auth-like headers
+        self.auth_headers = biolmai.auth.get_user_auth_header()
+        self.action_class_strings = tuple(
+            [c.__name__.replace("Action", "").lower() for c in self.action_classes]
+        )
+
+    def post_batches(self, dat, slug, action, payload_maker, resp_key):
+        keep_batches = dat.loc[~dat.batch.isnull(), ["text", "batch"]]
+        if keep_batches.shape[0] == 0:
+            pass  # Do nothing - we made nice JSON errors to return in the DF
+            # err = "No inputs found following local validation"
+            # raise AssertionError(err)
+        if keep_batches.shape[0] > 0:
+            api_resps = async_api_call_wrapper(
+                keep_batches, slug, action, payload_maker, resp_key
+            )
+            if isinstance(api_resps, pd.DataFrame):
+                batch_res = api_resps.explode("api_resp")  # Should be lists of results
+                len_res = batch_res.shape[0]
+            else:
+                batch_res = pd.DataFrame({"api_resp": api_resps})
+                len_res = batch_res.shape[0]
+            orig_request_rows = keep_batches.shape[0]
+            if len_res != orig_request_rows:
+                err = "Response rows ({}) mismatch with input rows ({})"
+                err = err.format(len_res, orig_request_rows)
+                raise AssertionError(err)
+
+            # Stack the results horizontally w/ original rows of batches
+            keep_batches["prev_idx"] = keep_batches.index
+            keep_batches.reset_index(drop=False, inplace=True)
+            batch_res.reset_index(drop=True, inplace=True)
+            keep_batches["api_resp"] = batch_res
+            keep_batches.set_index("prev_idx", inplace=True)
+            dat = dat.join(keep_batches.reindex(["api_resp"], axis=1))
+        else:
+            dat["api_resp"] = None
+        return dat
+
+    def unpack_local_validations(self, dat):
+        dat.loc[dat.api_resp.isnull(), "api_resp"] = (
+            dat.loc[~dat.validation.isnull(), "validation"]
+            .apply(
+                predict_resp_many_in_one_to_many_singles, args=(None, None, True, None)
+            )
+            .explode()
+        )
+
+        return dat
+
+    @convert_input
+    @validate
+    def predict(self, dat):
+        dat = self.post_batches(dat, self.slug, "predict", INST_DAT_TXT, "predictions")
+        dat = self.unpack_local_validations(dat)
+        return dat.api_resp.replace(np.nan, None).tolist()
+
+    def infer(self, dat):
+        return self.predict(dat)
+
+    @convert_input
+    @validate
+    def transform(self, dat):
+        dat = self.post_batches(
+            dat, self.slug, "transform", INST_DAT_TXT, "predictions"
+        )
+        dat = self.unpack_local_validations(dat)
+        return dat.api_resp.replace(np.nan, None).tolist()
+
+    # @convert_input
+    # @validate
+    # def encode(self, dat):
+    #     # NOTE: we defined this for the specific case of ESM2
+    #     # TODO: this will be need again in v2 of API contract
+    #     dat = self.post_batches(dat, self.slug, "transform",
+    #                             INST_DAT_TXT, "embeddings")
+    #     dat = self.unpack_local_validations(dat)
+    #     return dat.api_resp.replace(np.nan, None).tolist()
+
+    @convert_input
+    @validate
+    def generate(self, dat):
+        dat = self.post_batches(dat, self.slug, "generate", INST_DAT_TXT, "generated")
+        dat = self.unpack_local_validations(dat)
+        return dat.api_resp.replace(np.nan, None).tolist()
+
+
+def retry_minutes(sess, URL, HEADERS, dat, timeout, mins):
+    """Retry for N minutes."""
+    HEADERS.update({"Content-Type": "application/json"})
+    attempts, max_attempts = 0, 5
+    try:
+        now = datetime.datetime.now()
+        try_until = now + datetime.timedelta(minutes=mins)
+        while datetime.datetime.now() < try_until and attempts < max_attempts:
+            response = None
+            try:
+                log.info(f"Trying {datetime.datetime.now()}")
+                response = sess.post(URL, headers=HEADERS, data=dat, timeout=timeout)
+                if response.status_code not in (400, 404):
+                    response.raise_for_status()
+                if "error" in response.json():
+                    raise ValueError(response.json().dumps())
+                else:
+                    break
+            except Exception as e:
+                log.warning(e)
+                if response:
+                    log.warning(response.text)
+                time.sleep(5)  # Wait 5 seconds between tries
+                attempts += 1
+        if response is None:
+            err = "Got Nonetype response"
+            raise ValueError(err)
+        elif "Server Error" in response.text:
+            err = "Got Server Error"
+            raise ValueError(err)
+    except Exception:
+        return response
+    return response
+
+
+def requests_retry_session(
+    retries=3,
+    backoff_factor=0.3,
+    status_forcelist=None,
+    session=None,
+):
+    if status_forcelist is None:
+        status_forcelist = list(range(400, 599))
+    session = session or requests.Session()
+    retry = Retry(
+        total=retries,
+        read=retries,
+        connect=retries,
+        backoff_factor=backoff_factor,
+        status_forcelist=status_forcelist,
+    )
+    adapter = HTTPAdapter(max_retries=retry)
+    session.mount("http://", adapter)
+    session.mount("https://", adapter)
+    return session
+
+
+class PredictAction:
+    def __str__(self):
+        return "PredictAction"
+
+
+class GenerateAction:
+    def __str__(self):
+        return "GenerateAction"
+
+
+class TransformAction:
+    def __str__(self):
+        return "TransformAction"
+
+
+# class EncodeAction:
+#     def __str__(self):
+#         return "EncodeAction"
+
+
+class ExplainAction:
+    def __str__(self):
+        return "ExplainAction"
+
+
+class SimilarityAction:
+    def __str__(self):
+        return "SimilarityAction"
+
+
+class FinetuneAction:
+    def __str__(self):
+        return "FinetuneAction"
--- biolmai-0.1.4/biolmai/asynch.py
+++ biolmai-0.1.7/biolmai/asynch.py
@@ -1,23 +1,15 @@
+import asyncio
+from asyncio import create_task, gather, run
+from itertools import zip_longest
+from typing import Dict, List
+
 import aiohttp.resolver
+from aiohttp import ClientSession
 
+from biolmai.auth import get_user_auth_header
 from biolmai.const import BASE_API_URL, MULTIPROCESS_THREADS
 
 aiohttp.resolver.DefaultResolver = aiohttp.resolver.AsyncResolver
-from aiohttp import ClientSession, TCPConnector
-from typing import List
-import json
-import asyncio
-
-from asyncio import create_task, gather, run, sleep
-
-
-
-async def get_one(session: ClientSession, slug: str, action: str,
-                  payload: dict, response_key: str):
-    pass
-
-
-from aiohttp import ClientSession
 
 
 async def get_one(session: ClientSession, url: str) -> None:
@@ -30,25 +22,31 @@ async def get_one(session: ClientSession, url: str) -> None:
     return text_resp
 
 
-async def get_one_biolm(session: ClientSession,
-                        url: str,
-                        pload: dict,
-                        headers: dict,
-                        response_key: str = None) -> None:
+async def get_one_biolm(
+    session: ClientSession,
+    url: str,
+    pload: dict,
+    headers: dict,
+    response_key: str = None,
+) -> None:
     print("Requesting", url)
-    pload_batch = pload.pop('batch')
-    pload_batch_size = pload.pop('batch_size')
+    pload_batch = pload.pop("batch")
+    pload_batch_size = pload.pop("batch_size")
     t = aiohttp.ClientTimeout(
-        total=1200,
-        # total timeout (time consists connection establishment for a new connection or waiting for a free connection from a pool if pool connection limits are exceeded) default value is 5 minutes, set to `None` or `0` for unlimited timeout
+        total=1600,  # 27 mins
+        # total timeout (time consists connection establishment for
+        # a new connection or waiting for a free connection from a
+        # pool if pool connection limits are exceeded) default value
+        # is 5 minutes, set to `None` or `0` for unlimited timeout
         sock_connect=None,
-        # Maximal number of seconds for connecting to a peer for a new connection, not given from a pool. See also connect.
+        # Maximal number of seconds for connecting to a peer for a
+        # new connection, not given from a pool. See also connect.
         sock_read=None
         # Maximal number of seconds for reading a portion of data from a peer
     )
     async with session.post(url, headers=headers, json=pload, timeout=t) as resp:
         resp_json = await resp.json()
-        resp_json['batch'] = pload_batch
+        resp_json["batch"] = pload_batch
         status_code = resp.status
         expected_root_key = response_key
         to_ret = []
@@ -61,9 +59,7 @@ async def get_one_biolm(session: ClientSession,
         else:
             raise ValueError("Unexpected response in parser")
         for idx, item in enumerate(list_of_individual_seq_results):
-            d = {'status_code': status_code,
-                 'batch_id': pload_batch,
-                 'batch_item': idx}
+            d = {"status_code": status_code, "batch_id": pload_batch, "batch_item": idx}
             if not status_code or status_code != 200:
                 d.update(item)  # Put all resp keys at root there
             else:
@@ -77,16 +73,15 @@ async def get_one_biolm(session: ClientSession,
     # await sleep(2) # for demo purposes
     # text_resp = text.strip().split("\n", 1)[0]
     # print("Got response from", url, text_resp)
-    return j
 
 
 async def async_range(count):
     for i in range(count):
-        yield(i)
+        yield (i)
         await asyncio.sleep(0.0)
 
 
-async def get_all(urls: List[str], num_concurrent: int) -> List:
+async def get_all(urls: List[str], num_concurrent: int) -> list:
     url_iterator = iter(urls)
     keep_going = True
     results = []
@@ -106,22 +101,26 @@ async def get_all(urls: List[str], num_concurrent: int) -> List:
     return results
 
 
-async def get_all_biolm(url: str,
-                        ploads: List[dict],
-                        headers: dict,
-                        num_concurrent: int,
-                        response_key: str = None) -> List:
+async def get_all_biolm(
+    url: str,
+    ploads: List[Dict],
+    headers: dict,
+    num_concurrent: int,
+    response_key: str = None,
+) -> list:
     ploads_iterator = iter(ploads)
     keep_going = True
     results = []
-    connector = aiohttp.TCPConnector(limit=100,
-                                     limit_per_host=50,
-                                     ttl_dns_cache=60)
+    connector = aiohttp.TCPConnector(limit=100, limit_per_host=50, ttl_dns_cache=60)
     ov_tout = aiohttp.ClientTimeout(
         total=None,
-        # total timeout (time consists connection establishment for a new connection or waiting for a free connection from a pool if pool connection limits are exceeded) default value is 5 minutes, set to `None` or `0` for unlimited timeout
+        # total timeout (time consists connection establishment for
+        # a new connection or waiting for a free connection from a
+        # pool if pool connection limits are exceeded) default value
+        # is 5 minutes, set to `None` or `0` for unlimited timeout
         sock_connect=None,
-        # Maximal number of seconds for connecting to a peer for a new connection, not given from a pool. See also connect.
+        # Maximal number of seconds for connecting to a peer for a
+        # new connection, not given from a pool. See also connect.
         sock_read=None
         # Maximal number of seconds for reading a portion of data from a peer
     )
@@ -134,35 +133,31 @@ async def get_all_biolm(url: str,
                 except StopIteration:
                     keep_going = False
                     break
-                new_task = create_task(get_one_biolm(session, url, pload,
-                                                     headers, response_key))
+                new_task = create_task(
+                    get_one_biolm(session, url, pload, headers, response_key)
+                )
                 tasks.append(new_task)
             res = await gather(*tasks)
             results.extend(res)
     return results
 
 
-async def async_main(urls, concurrency) -> List:
+async def async_main(urls, concurrency) -> list:
     return await get_all(urls, concurrency)
 
 
-async def async_api_calls(model_name,
-                          action,
-                          headers,
-                          payloads,
-                          response_key=None):
+async def async_api_calls(model_name, action, headers, payloads, response_key=None):
     """Hit an arbitrary BioLM model inference API."""
     # Normally would POST multiple sequences at once for greater efficiency,
     # but for simplicity sake will do one at at time right now
-    url = f'{BASE_API_URL}/models/{model_name}/{action}/'
+    url = f"{BASE_API_URL}/models/{model_name}/{action}/"
 
     if not isinstance(payloads, (list, dict)):
         err = "API request payload must be a list or dict, got {}"
         raise AssertionError(err.format(type(payloads)))
 
     concurrency = int(MULTIPROCESS_THREADS)
-    return await get_all_biolm(url, payloads, headers, concurrency,
-                               response_key)
+    return await get_all_biolm(url, payloads, headers, concurrency, response_key)
 
     # payload = json.dumps(payload)
     # session = requests_retry_session()
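
Note: get_all_biolm drains the payload iterator in waves of num_concurrent tasks: it creates up to num_concurrent coroutines via create_task(...), gathers them, then pulls the next wave until the iterator is exhausted. The same pattern in isolation, as a self-contained sketch independent of the BioLM endpoints:

    import asyncio
    from typing import Awaitable, Callable, Iterable, List

    async def gather_in_waves(items: Iterable,
                              worker: Callable[..., Awaitable],
                              num_concurrent: int) -> List:
        """Run worker(item) over items, at most num_concurrent at a time."""
        it = iter(items)
        results, keep_going = [], True
        while keep_going:
            tasks = []
            for _ in range(num_concurrent):
                try:
                    item = next(it)
                except StopIteration:
                    keep_going = False
                    break
                tasks.append(asyncio.create_task(worker(item)))
            results.extend(await asyncio.gather(*tasks))
        return results

One consequence of this wave design is that each wave waits on its slowest request before the next wave begins; a semaphore-bounded gather would keep all num_concurrent slots busy instead.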
@@ -183,3 +178,45 @@ async def async_api_calls(model_name,
     # headers = get_user_auth_header() # Need to re-get these now
     # response = retry_minutes(session, url, headers, payload, tout, mins=10)
     # return response
+
+
+def async_api_call_wrapper(grouped_df, slug, action, payload_maker, response_key):
+    """Wrap API calls to assist with sequence validation as a pre-cursor to
+    each API call.
+    """
+    model_name = slug
+    # payload = payload_maker(grouped_df)
+    init_ploads = grouped_df.groupby("batch").apply(
+        payload_maker, include_batch_size=True
+    )
+    ploads = init_ploads.to_list()
+    init_ploads = init_ploads.to_frame(name="pload")
+    init_ploads["batch"] = init_ploads.index
+    init_ploads = init_ploads.reset_index(drop=True)
+    assert len(ploads) == init_ploads.shape[0]
+    for inst, b in zip_longest(ploads, init_ploads["batch"].to_list()):
+        if inst is None or b is None:
+            raise ValueError(
+                "ploads and init_ploads['batch'] are not of the same length"
+            )
+        inst["batch"] = b
+
+    headers = get_user_auth_header()  # Need to pull each time
+    # urls = [
+    #     "https://github.com",
+    #     "https://stackoverflow.com",
+    #     "https://python.org",
+    # ]
+    # concurrency = 3
+    api_resp = run(async_api_calls(model_name, action, headers, ploads, response_key))
+    api_resp = [item for sublist in api_resp for item in sublist]
+    api_resp = sorted(api_resp, key=lambda x: x["batch_id"])
+    # print(api_resp)
+    # api_resp = biolmai.api_call(model_name, action, headers, payload,
+    #                             response_key)
+    # resp_json = api_resp.json()
+    # batch_id = int(grouped_df.batch.iloc[0])
+    # batch_size = grouped_df.shape[0]
+    # response = predict_resp_many_in_one_to_many_singles(
+    #     resp_json, api_resp.status_code, batch_id, None, batch_size)
+    return api_resp
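
Note: async_api_call_wrapper expects the grouped DataFrame produced by the @validate decorator in api.py (a text column plus a batch number per valid row) and a payload_maker callable that turns each batch group into one request payload. A sketch of that contract with a stand-in payload maker; the real one is INST_DAT_TXT in biolmai/payloads.py, and the exact payload shape shown here is an assumption for illustration:

    import pandas as pd

    def toy_payload_maker(group: pd.DataFrame, include_batch_size: bool = False) -> dict:
        # Stand-in for biolmai.payloads.INST_DAT_TXT; the "instances" layout
        # is assumed, not confirmed by this diff.
        pload = {"instances": [{"data": {"text": t}} for t in group["text"]]}
        if include_batch_size:
            pload["batch_size"] = group.shape[0]  # popped client-side in get_one_biolm
        return pload

    df = pd.DataFrame({"text": ["MDNELE", "MKTAYI", "MSILVT"], "batch": [0, 0, 1]})
    ploads = df.groupby("batch").apply(toy_payload_maker, include_batch_size=True).to_list()
    # async_api_call_wrapper tags each payload with its batch id, dispatches them
    # concurrently via async_api_calls(), then flattens and re-sorts the
    # per-sequence results by batch_id.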