biolmai 0.1.5__py2.py3-none-any.whl → 0.1.8__py2.py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of biolmai has been flagged as possibly problematic.

biolmai/__init__.py CHANGED
@@ -1,11 +1,7 @@
  """Top-level package for BioLM AI."""
  __author__ = """Nikhil Haas"""
- __email__ = 'nikhil@biolm.ai'
- __version__ = '0.1.5'
+ __email__ = "nikhil@biolm.ai"
+ __version__ = '0.1.8'

- from biolmai.auth import get_api_token
- from biolmai.cls import ESMFoldSingleChain, ESMFoldMultiChain, ESM2Embeddings, ESM1v1, ESM1v2, ESM1v3, ESM1v4, ESM1v5

- __all__ = [
-
- ]
+ __all__ = []
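The hunk above drops the package-root re-exports and empties `__all__`, so the model classes are no longer importable from `biolmai` directly. A minimal sketch of the assumed impact on downstream imports (an assumption drawn from this hunk alone; it presumes the `biolmai.cls` module still ships in 0.1.8, which this diff does not show, and the sequence string is purely illustrative):

# Worked in 0.1.5 via the package-root re-exports removed above:
# from biolmai import ESMFoldSingleChain
# Assumed equivalent in 0.1.8, importing from the defining submodule:
from biolmai.cls import ESMFoldSingleChain

model = ESMFoldSingleChain()
result = model.predict(["MSILVTRPSPAGEEL"])  # hypothetical single-sequence call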
biolmai/api.py CHANGED
@@ -1,33 +1,29 @@
  """References to API endpoints."""
  import datetime
+ import inspect
  import time
+ from functools import lru_cache

+ import numpy as np
+ import pandas as pd
  import requests
  from requests.adapters import HTTPAdapter
+ from requests.packages.urllib3.util.retry import Retry

- import biolmai.auth
  import biolmai
- import inspect
- import pandas as pd
- import numpy as np
+ import biolmai.auth
  from biolmai.asynch import async_api_call_wrapper
-
  from biolmai.biolmai import log
  from biolmai.const import MULTIPROCESS_THREADS
- from functools import lru_cache
-
  from biolmai.payloads import INST_DAT_TXT, predict_resp_many_in_one_to_many_singles


  @lru_cache(maxsize=64)
  def validate_endpoint_action(allowed_classes, method_name, api_class_name):
- action_method_name = method_name.split('.')[-1]
+ action_method_name = method_name.split(".")[-1]
  if action_method_name not in allowed_classes:
- err = 'Only {} supported on {}'
- err = err.format(
- list(allowed_classes),
- api_class_name
- )
+ err = "Only {} supported on {}"
+ err = err.format(list(allowed_classes), api_class_name)
  raise AssertionError(err)


@@ -47,25 +43,23 @@ def validate(f):
  # like ESMFoldSinglechain.
  class_obj_self = args[0]
  try:
- is_method = inspect.getfullargspec(f)[0][0] == 'self'
- except:
+ is_method = inspect.getfullargspec(f)[0][0] == "self"
+ except Exception:
  is_method = False

  # Is the function we decorated a class method?
  if is_method:
- name = '{}.{}.{}'.format(f.__module__,
- class_obj_self.__class__.__name__,
- f.__name__)
+ name = f"{f.__module__}.{class_obj_self.__class__.__name__}.{f.__name__}"
  else:
- name = '{}.{}'.format(f.__module__, f.__name__)
+ name = f"{f.__module__}.{f.__name__}"

  if is_method:
  # Splits name, e.g. 'biolmai.api.ESMFoldSingleChain.predict'
- action_method_name = name.split('.')[-1]
+ action_method_name = name.split(".")[-1]
  validate_endpoint_action(
  class_obj_self.action_class_strings,
  action_method_name,
- class_obj_self.__class__.__name__
+ class_obj_self.__class__.__name__,
  )

  input_data = args[1]
@@ -73,35 +67,38 @@ def validate(f):
  for c in class_obj_self.seq_classes:
  # Validate input data against regex
  if class_obj_self.multiprocess_threads:
- validation = input_data.text.apply(text_validator, args=(c, ))
+ validation = input_data.text.apply(text_validator, args=(c,))
  else:
- validation = input_data.text.apply(text_validator, args=(c, ))
- if 'validation' not in input_data.columns:
- input_data['validation'] = validation
+ validation = input_data.text.apply(text_validator, args=(c,))
+ if "validation" not in input_data.columns:
+ input_data["validation"] = validation
  else:
- input_data['validation'] = input_data['validation'].str.cat(
- validation, sep='\n', na_rep='')
+ input_data["validation"] = input_data["validation"].str.cat(
+ validation, sep="\n", na_rep=""
+ )

  # Mark your batches, excluding invalid rows
  valid_dat = input_data.loc[input_data.validation.isnull(), :].copy()
  N = class_obj_self.batch_size # N rows will go per API request
  # JOIN back, which is by index
  if valid_dat.shape[0] != input_data.shape[0]:
- valid_dat['batch'] = np.arange(valid_dat.shape[0])//N
+ valid_dat["batch"] = np.arange(valid_dat.shape[0]) // N
  input_data = input_data.merge(
- valid_dat.batch, left_index=True, right_index=True, how='left')
+ valid_dat.batch, left_index=True, right_index=True, how="left"
+ )
  else:
- input_data['batch'] = np.arange(input_data.shape[0])//N
+ input_data["batch"] = np.arange(input_data.shape[0]) // N

  res = f(class_obj_self, input_data, **kwargs)
  return res
+
  return wrapper


  def convert_input(f):
  def wrapper(*args, **kwargs):
  # Get the user-input data argument to the decorated function
- class_obj_self = args[0]
+ # class_obj_self = args[0]
  input_data = args[1]
  # Make sure we have expected input types
  acceptable_inputs = (str, list, tuple, np.ndarray, pd.DataFrame)
@@ -119,12 +116,13 @@ def convert_input(f):
  if isinstance(input_data, pd.DataFrame) and len(input_data.shape) > 1:
  err = "Detected Pandas DataFrame - input a single vector or Series"
  raise AssertionError(err)
- input_data = pd.DataFrame(input_data, columns=['text'])
+ input_data = pd.DataFrame(input_data, columns=["text"])
  return f(args[0], input_data, **kwargs)
+
  return wrapper


- class APIEndpoint(object):
+ class APIEndpoint:
  batch_size = 3 # Overwrite in parent classes as needed

  def __init__(self, multiprocess_threads=None):
@@ -135,29 +133,25 @@ class APIEndpoint(object):
  self.multiprocess_threads = MULTIPROCESS_THREADS # Could be False
  # Get correct auth-like headers
  self.auth_headers = biolmai.auth.get_user_auth_header()
- self.action_class_strings = tuple([
- c.__name__.replace('Action', '').lower() for c in self.action_classes
- ])
+ self.action_class_strings = tuple(
+ [c.__name__.replace("Action", "").lower() for c in self.action_classes]
+ )

  def post_batches(self, dat, slug, action, payload_maker, resp_key):
- keep_batches = dat.loc[~dat.batch.isnull(), ['text', 'batch']]
+ keep_batches = dat.loc[~dat.batch.isnull(), ["text", "batch"]]
  if keep_batches.shape[0] == 0:
  pass # Do nothing - we made nice JSON errors to return in the DF
  # err = "No inputs found following local validation"
  # raise AssertionError(err)
  if keep_batches.shape[0] > 0:
  api_resps = async_api_call_wrapper(
- keep_batches,
- slug,
- action,
- payload_maker,
- resp_key
+ keep_batches, slug, action, payload_maker, resp_key
  )
  if isinstance(api_resps, pd.DataFrame):
- batch_res = api_resps.explode('api_resp') # Should be lists of results
+ batch_res = api_resps.explode("api_resp") # Should be lists of results
  len_res = batch_res.shape[0]
  else:
- batch_res = pd.DataFrame({'api_resp': api_resps})
+ batch_res = pd.DataFrame({"api_resp": api_resps})
  len_res = batch_res.shape[0]
  orig_request_rows = keep_batches.shape[0]
  if len_res != orig_request_rows:
@@ -166,29 +160,31 @@ class APIEndpoint(object):
  raise AssertionError(err)

  # Stack the results horizontally w/ original rows of batches
- keep_batches['prev_idx'] = keep_batches.index
+ keep_batches["prev_idx"] = keep_batches.index
  keep_batches.reset_index(drop=False, inplace=True)
  batch_res.reset_index(drop=True, inplace=True)
- keep_batches['api_resp'] = batch_res
- keep_batches.set_index('prev_idx', inplace=True)
- dat = dat.join(keep_batches.reindex(['api_resp'], axis=1))
+ keep_batches["api_resp"] = batch_res
+ keep_batches.set_index("prev_idx", inplace=True)
+ dat = dat.join(keep_batches.reindex(["api_resp"], axis=1))
  else:
- dat['api_resp'] = None
+ dat["api_resp"] = None
  return dat

  def unpack_local_validations(self, dat):
- dat.loc[
- dat.api_resp.isnull(), 'api_resp'
- ] = dat.loc[~dat.validation.isnull(), 'validation'].apply(
- predict_resp_many_in_one_to_many_singles,
- args=(None, None, True, None)).explode()
+ dat.loc[dat.api_resp.isnull(), "api_resp"] = (
+ dat.loc[~dat.validation.isnull(), "validation"]
+ .apply(
+ predict_resp_many_in_one_to_many_singles, args=(None, None, True, None)
+ )
+ .explode()
+ )

  return dat

  @convert_input
  @validate
  def predict(self, dat):
- dat = self.post_batches(dat, self.slug, 'predict', INST_DAT_TXT, 'predictions')
+ dat = self.post_batches(dat, self.slug, "predict", INST_DAT_TXT, "predictions")
  dat = self.unpack_local_validations(dat)
  return dat.api_resp.replace(np.nan, None).tolist()

@@ -198,21 +194,33 @@ class APIEndpoint(object):
  @convert_input
  @validate
  def transform(self, dat):
- dat = self.post_batches(dat, self.slug, 'transform', INST_DAT_TXT, 'predictions')
+ dat = self.post_batches(
+ dat, self.slug, "transform", INST_DAT_TXT, "predictions"
+ )
  dat = self.unpack_local_validations(dat)
  return dat.api_resp.replace(np.nan, None).tolist()

+ # @convert_input
+ # @validate
+ # def encode(self, dat):
+ # # NOTE: we defined this for the specific case of ESM2
+ # # TODO: this will be need again in v2 of API contract
+ # dat = self.post_batches(dat, self.slug, "transform",
+ # INST_DAT_TXT, "embeddings")
+ # dat = self.unpack_local_validations(dat)
+ # return dat.api_resp.replace(np.nan, None).tolist()
+
  @convert_input
  @validate
  def generate(self, dat):
- dat = self.post_batches(dat, self.slug, 'generate', INST_DAT_TXT, 'generated')
+ dat = self.post_batches(dat, self.slug, "generate", INST_DAT_TXT, "generated")
  dat = self.unpack_local_validations(dat)
  return dat.api_resp.replace(np.nan, None).tolist()


  def retry_minutes(sess, URL, HEADERS, dat, timeout, mins):
  """Retry for N minutes."""
- HEADERS.update({'Content-Type': 'application/json'})
+ HEADERS.update({"Content-Type": "application/json"})
  attempts, max_attempts = 0, 5
  try:
  now = datetime.datetime.now()
@@ -220,16 +228,11 @@ def retry_minutes(sess, URL, HEADERS, dat, timeout, mins):
  while datetime.datetime.now() < try_until and attempts < max_attempts:
  response = None
  try:
- log.info('Trying {}'.format(datetime.datetime.now()))
- response = sess.post(
- URL,
- headers=HEADERS,
- data=dat,
- timeout=timeout
- )
+ log.info(f"Trying {datetime.datetime.now()}")
+ response = sess.post(URL, headers=HEADERS, data=dat, timeout=timeout)
  if response.status_code not in (400, 404):
  response.raise_for_status()
- if 'error' in response.json():
+ if "error" in response.json():
  raise ValueError(response.json().dumps())
  else:
  break
@@ -242,10 +245,10 @@ def retry_minutes(sess, URL, HEADERS, dat, timeout, mins):
  if response is None:
  err = "Got Nonetype response"
  raise ValueError(err)
- elif 'Server Error' in response.text:
+ elif "Server Error" in response.text:
  err = "Got Server Error"
  raise ValueError(err)
- except Exception as e:
+ except Exception:
  return response
  return response

@@ -253,54 +256,55 @@ def retry_minutes(sess, URL, HEADERS, dat, timeout, mins):
  def requests_retry_session(
  retries=3,
  backoff_factor=0.3,
- status_forcelist=list(range(400, 599)),
+ status_forcelist=None,
  session=None,
  ):
+ if status_forcelist is None:
+ status_forcelist = list(range(400, 599))
  session = session or requests.Session()
  retry = Retry(
  total=retries,
  read=retries,
  connect=retries,
  backoff_factor=backoff_factor,
- status_forcelist=status_forcelist
+ status_forcelist=status_forcelist,
  )
  adapter = HTTPAdapter(max_retries=retry)
- session.mount('http://', adapter)
- session.mount('https://', adapter)
+ session.mount("http://", adapter)
+ session.mount("https://", adapter)
  return session


- class PredictAction(object):
-
+ class PredictAction:
  def __str__(self):
- return 'PredictAction'
+ return "PredictAction"


- class GenerateAction(object):
-
+ class GenerateAction:
  def __str__(self):
- return 'GenerateAction'
-
+ return "GenerateAction"

- class TransformAction(object):

+ class TransformAction:
  def __str__(self):
- return 'TransformAction'
+ return "TransformAction"


- class ExplainAction(object):
+ # class EncodeAction:
+ # def __str__(self):
+ # return "EncodeAction"

- def __str__(self):
- return 'ExplainAction'

+ class ExplainAction:
+ def __str__(self):
+ return "ExplainAction"

- class SimilarityAction(object):

+ class SimilarityAction:
  def __str__(self):
- return 'SimilarityAction'
-
+ return "SimilarityAction"

- class FinetuneAction(object):

+ class FinetuneAction:
  def __str__(self):
- return 'FinetuneAction'
+ return "FinetuneAction"
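The `requests_retry_session` change in the api.py diff above replaces the mutable default argument `status_forcelist=list(range(400, 599))` with a `None` sentinel that is resolved inside the function body. A short, self-contained sketch of the same pattern and how such a session would be used (the function name `retry_session` and the URL are illustrative, not part of the package):

import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry


def retry_session(retries=3, backoff_factor=0.3, status_forcelist=None):
    # Sentinel default: the fallback list is built fresh on each call,
    # instead of one list object being shared across calls the way a
    # mutable default argument would be.
    if status_forcelist is None:
        status_forcelist = list(range(400, 599))
    session = requests.Session()
    retry = Retry(
        total=retries,
        read=retries,
        connect=retries,
        backoff_factor=backoff_factor,
        status_forcelist=status_forcelist,
    )
    adapter = HTTPAdapter(max_retries=retry)
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    return session


resp = retry_session().get("https://example.com/health")  # illustrative URL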
biolmai/asynch.py CHANGED
@@ -1,24 +1,15 @@
+ import asyncio
+ from asyncio import create_task, gather, run
+ from itertools import zip_longest
+ from typing import Dict, List
+
  import aiohttp.resolver
+ from aiohttp import ClientSession

  from biolmai.auth import get_user_auth_header
  from biolmai.const import BASE_API_URL, MULTIPROCESS_THREADS

  aiohttp.resolver.DefaultResolver = aiohttp.resolver.AsyncResolver
- from aiohttp import ClientSession, TCPConnector
- from typing import List
- import json
- import asyncio
-
- from asyncio import create_task, gather, run, sleep
-
-
-
- async def get_one(session: ClientSession, slug: str, action: str,
- payload: dict, response_key: str):
- pass
-
-
- from aiohttp import ClientSession


  async def get_one(session: ClientSession, url: str) -> None:
@@ -31,25 +22,31 @@ async def get_one(session: ClientSession, url: str) -> None:
  return text_resp


- async def get_one_biolm(session: ClientSession,
- url: str,
- pload: dict,
- headers: dict,
- response_key: str = None) -> None:
+ async def get_one_biolm(
+ session: ClientSession,
+ url: str,
+ pload: dict,
+ headers: dict,
+ response_key: str = None,
+ ) -> None:
  print("Requesting", url)
- pload_batch = pload.pop('batch')
- pload_batch_size = pload.pop('batch_size')
+ pload_batch = pload.pop("batch")
+ pload_batch_size = pload.pop("batch_size")
  t = aiohttp.ClientTimeout(
  total=1600, # 27 mins
- # total timeout (time consists connection establishment for a new connection or waiting for a free connection from a pool if pool connection limits are exceeded) default value is 5 minutes, set to `None` or `0` for unlimited timeout
+ # total timeout (time consists connection establishment for
+ # a new connection or waiting for a free connection from a
+ # pool if pool connection limits are exceeded) default value
+ # is 5 minutes, set to `None` or `0` for unlimited timeout
  sock_connect=None,
- # Maximal number of seconds for connecting to a peer for a new connection, not given from a pool. See also connect.
+ # Maximal number of seconds for connecting to a peer for a
+ # new connection, not given from a pool. See also connect.
  sock_read=None
  # Maximal number of seconds for reading a portion of data from a peer
  )
  async with session.post(url, headers=headers, json=pload, timeout=t) as resp:
  resp_json = await resp.json()
- resp_json['batch'] = pload_batch
+ resp_json["batch"] = pload_batch
  status_code = resp.status
  expected_root_key = response_key
  to_ret = []
@@ -62,9 +59,7 @@ async def get_one_biolm(session: ClientSession,
  else:
  raise ValueError("Unexpected response in parser")
  for idx, item in enumerate(list_of_individual_seq_results):
- d = {'status_code': status_code,
- 'batch_id': pload_batch,
- 'batch_item': idx}
+ d = {"status_code": status_code, "batch_id": pload_batch, "batch_item": idx}
  if not status_code or status_code != 200:
  d.update(item) # Put all resp keys at root there
  else:
@@ -78,16 +73,15 @@ async def get_one_biolm(session: ClientSession,
  # await sleep(2) # for demo purposes
  # text_resp = text.strip().split("\n", 1)[0]
  # print("Got response from", url, text_resp)
- return j


  async def async_range(count):
  for i in range(count):
- yield(i)
+ yield (i)
  await asyncio.sleep(0.0)


- async def get_all(urls: List[str], num_concurrent: int) -> List:
+ async def get_all(urls: List[str], num_concurrent: int) -> list:
  url_iterator = iter(urls)
  keep_going = True
  results = []
@@ -107,22 +101,26 @@ async def get_all(urls: List[str], num_concurrent: int) -> List:
  return results


- async def get_all_biolm(url: str,
- ploads: List[dict],
- headers: dict,
- num_concurrent: int,
- response_key: str = None) -> List:
+ async def get_all_biolm(
+ url: str,
+ ploads: List[Dict],
+ headers: dict,
+ num_concurrent: int,
+ response_key: str = None,
+ ) -> list:
  ploads_iterator = iter(ploads)
  keep_going = True
  results = []
- connector = aiohttp.TCPConnector(limit=100,
- limit_per_host=50,
- ttl_dns_cache=60)
+ connector = aiohttp.TCPConnector(limit=100, limit_per_host=50, ttl_dns_cache=60)
  ov_tout = aiohttp.ClientTimeout(
  total=None,
- # total timeout (time consists connection establishment for a new connection or waiting for a free connection from a pool if pool connection limits are exceeded) default value is 5 minutes, set to `None` or `0` for unlimited timeout
+ # total timeout (time consists connection establishment for
+ # a new connection or waiting for a free connection from a
+ # pool if pool connection limits are exceeded) default value
+ # is 5 minutes, set to `None` or `0` for unlimited timeout
  sock_connect=None,
- # Maximal number of seconds for connecting to a peer for a new connection, not given from a pool. See also connect.
+ # Maximal number of seconds for connecting to a peer for a
+ # new connection, not given from a pool. See also connect.
  sock_read=None
  # Maximal number of seconds for reading a portion of data from a peer
  )
@@ -135,35 +133,31 @@ async def get_all_biolm(url: str,
  except StopIteration:
  keep_going = False
  break
- new_task = create_task(get_one_biolm(session, url, pload,
- headers, response_key))
+ new_task = create_task(
+ get_one_biolm(session, url, pload, headers, response_key)
+ )
  tasks.append(new_task)
  res = await gather(*tasks)
  results.extend(res)
  return results


- async def async_main(urls, concurrency) -> List:
+ async def async_main(urls, concurrency) -> list:
  return await get_all(urls, concurrency)


- async def async_api_calls(model_name,
- action,
- headers,
- payloads,
- response_key=None):
+ async def async_api_calls(model_name, action, headers, payloads, response_key=None):
  """Hit an arbitrary BioLM model inference API."""
  # Normally would POST multiple sequences at once for greater efficiency,
  # but for simplicity sake will do one at at time right now
- url = f'{BASE_API_URL}/models/{model_name}/{action}/'
+ url = f"{BASE_API_URL}/models/{model_name}/{action}/"

  if not isinstance(payloads, (list, dict)):
  err = "API request payload must be a list or dict, got {}"
  raise AssertionError(err.format(type(payloads)))

  concurrency = int(MULTIPROCESS_THREADS)
- return await get_all_biolm(url, payloads, headers, concurrency,
- response_key)
+ return await get_all_biolm(url, payloads, headers, concurrency, response_key)

  # payload = json.dumps(payload)
  # session = requests_retry_session()
@@ -186,33 +180,37 @@ async def async_api_calls(model_name,
  # return response


- def async_api_call_wrapper(grouped_df, slug, action, payload_maker,
- response_key):
+ def async_api_call_wrapper(grouped_df, slug, action, payload_maker, response_key):
  """Wrap API calls to assist with sequence validation as a pre-cursor to
  each API call.
  """
  model_name = slug
  # payload = payload_maker(grouped_df)
- init_ploads = grouped_df.groupby('batch').apply(payload_maker, include_batch_size=True)
+ init_ploads = grouped_df.groupby("batch").apply(
+ payload_maker, include_batch_size=True
+ )
  ploads = init_ploads.to_list()
- init_ploads = init_ploads.to_frame(name='pload')
- init_ploads['batch'] = init_ploads.index
+ init_ploads = init_ploads.to_frame(name="pload")
+ init_ploads["batch"] = init_ploads.index
  init_ploads = init_ploads.reset_index(drop=True)
  assert len(ploads) == init_ploads.shape[0]
- for inst, b in zip(ploads, init_ploads['batch'].to_list()):
- inst['batch'] = b
+ for inst, b in zip_longest(ploads, init_ploads["batch"].to_list()):
+ if inst is None or b is None:
+ raise ValueError(
+ "ploads and init_ploads['batch'] are not of the same length"
+ )
+ inst["batch"] = b

  headers = get_user_auth_header() # Need to pull each time
- urls = [
- "https://github.com",
- "https://stackoverflow.com",
- "https://python.org",
- ]
+ # urls = [
+ # "https://github.com",
+ # "https://stackoverflow.com",
+ # "https://python.org",
+ # ]
  # concurrency = 3
- api_resp = run(async_api_calls(model_name, action, headers,
- ploads, response_key))
+ api_resp = run(async_api_calls(model_name, action, headers, ploads, response_key))
  api_resp = [item for sublist in api_resp for item in sublist]
- api_resp = sorted(api_resp, key=lambda x: x['batch_id'])
+ api_resp = sorted(api_resp, key=lambda x: x["batch_id"])
  # print(api_resp)
  # api_resp = biolmai.api_call(model_name, action, headers, payload,
  # response_key)
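The last hunk replaces `zip` with `itertools.zip_longest` plus an explicit `None` check in `async_api_call_wrapper`, so a length mismatch between the payloads and their batch labels now raises instead of being silently truncated. A standalone sketch of that guard outside the pandas context (the toy payloads and variable names are illustrative only):

from itertools import zip_longest

payloads = [{"text": "SEQ1"}, {"text": "SEQ2"}]  # illustrative toy payloads
batch_ids = [0, 1]

for pload, batch in zip_longest(payloads, batch_ids):
    # zip() would silently stop at the shorter input; zip_longest pads the
    # shorter side with None, which the check below turns into a hard error.
    if pload is None or batch is None:
        raise ValueError("payloads and batch_ids are not of the same length")
    pload["batch"] = batch

print(payloads)  # each payload now carries its batch label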