biolmai 0.1.8__tar.gz → 0.1.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of biolmai might be problematic. Click here for more details.

Files changed (97)
  1. {biolmai-0.1.8 → biolmai-0.1.10}/PKG-INFO +2 -1
  2. {biolmai-0.1.8 → biolmai-0.1.10}/biolmai/__init__.py +2 -1
  3. {biolmai-0.1.8 → biolmai-0.1.10}/biolmai/api.py +121 -84
  4. {biolmai-0.1.8 → biolmai-0.1.10}/biolmai/asynch.py +16 -8
  5. biolmai-0.1.10/biolmai/cls.py +176 -0
  6. {biolmai-0.1.8 → biolmai-0.1.10}/biolmai/const.py +2 -1
  7. {biolmai-0.1.8 → biolmai-0.1.10}/biolmai/payloads.py +13 -2
  8. biolmai-0.1.10/biolmai/validate.py +159 -0
  9. {biolmai-0.1.8 → biolmai-0.1.10}/biolmai.egg-info/PKG-INFO +2 -1
  10. biolmai-0.1.10/biolmai.egg-info/SOURCES.txt +83 -0
  11. {biolmai-0.1.8 → biolmai-0.1.10}/biolmai.egg-info/requires.txt +3 -0
  12. {biolmai-0.1.8 → biolmai-0.1.10}/docs/conf.py +1 -1
  13. {biolmai-0.1.8 → biolmai-0.1.10}/docs/index.rst +11 -8
  14. biolmai-0.1.10/docs/model-docs/ablang/AbLang_API.rst +200 -0
  15. biolmai-0.1.10/docs/model-docs/ablang/AbLang_Additional.rst +94 -0
  16. biolmai-0.1.10/docs/model-docs/ablang/index.rst +10 -0
  17. biolmai-0.1.10/docs/model-docs/biolmtox/BioLMTox_API.rst +293 -0
  18. biolmai-0.1.10/docs/model-docs/biolmtox/BioLMTox_Additional.rst +62 -0
  19. biolmai-0.1.10/docs/model-docs/biolmtox/index.rst +10 -0
  20. biolmai-0.1.10/docs/model-docs/dnabert/DNABERT_Additional.rst +59 -0
  21. biolmai-0.1.8/docs/model-docs/DNABERT.rst → biolmai-0.1.10/docs/model-docs/dnabert/classifier_ft.rst +34 -67
  22. biolmai-0.1.10/docs/model-docs/dnabert/index.rst +10 -0
  23. biolmai-0.1.10/docs/model-docs/esm1v/ESM-1v_API.rst +196 -0
  24. biolmai-0.1.10/docs/model-docs/esm1v/ESM-1v_Additional.rst +89 -0
  25. biolmai-0.1.10/docs/model-docs/esm1v/index.rst +10 -0
  26. biolmai-0.1.10/docs/model-docs/esm2/ESM2_API.rst +450 -0
  27. biolmai-0.1.10/docs/model-docs/esm2/ESM2_Additional.rst +99 -0
  28. biolmai-0.1.10/docs/model-docs/esm2/index.rst +10 -0
  29. biolmai-0.1.10/docs/model-docs/esmfold/ESMFold_API.rst +166 -0
  30. biolmai-0.1.10/docs/model-docs/esmfold/ESMFold_Additional.rst +108 -0
  31. biolmai-0.1.10/docs/model-docs/esmfold/index.rst +10 -0
  32. biolmai-0.1.8/docs/model-docs/ESM_InverseFold.rst → biolmai-0.1.10/docs/model-docs/esmif/ESM_InverseFold_API.rst +97 -163
  33. biolmai-0.1.10/docs/model-docs/esmif/ESM_InverseFold_Additional.rst +104 -0
  34. biolmai-0.1.10/docs/model-docs/esmif/index.rst +10 -0
  35. biolmai-0.1.10/docs/model-docs/finetuning/index.rst +11 -0
  36. biolmai-0.1.10/docs/model-docs/progen2/ProGen2_API.rst +222 -0
  37. biolmai-0.1.8/docs/model-docs/progen2/ProGen2_OAS.rst → biolmai-0.1.10/docs/model-docs/progen2/ProGen2_Additional.rst +20 -196
  38. biolmai-0.1.10/docs/model-docs/proteinfer/ProteInfer_API.rst +181 -0
  39. biolmai-0.1.10/docs/model-docs/proteinfer/ProteInfer_Additional.rst +92 -0
  40. biolmai-0.1.10/docs/model-docs/proteinfer/index.rst +10 -0
  41. biolmai-0.1.10/docs/model-docs/protgpt2/ProtGPT2.rst +75 -0
  42. biolmai-0.1.8/docs/model-docs/ProtGPT2.rst → biolmai-0.1.10/docs/model-docs/protgpt2/generator_ft.rst +32 -68
  43. biolmai-0.1.10/docs/model-docs/protgpt2/index.rst +10 -0
  44. biolmai-0.1.10/docs/python-client/get_started/authentication.rst +82 -0
  45. {biolmai-0.1.8 → biolmai-0.1.10}/docs/python-client/get_started/installation.rst +2 -2
  46. biolmai-0.1.10/docs/python-client/get_started/quickstart.rst +25 -0
  47. {biolmai-0.1.8 → biolmai-0.1.10}/docs/python-client/index.rst +1 -1
  48. {biolmai-0.1.8 → biolmai-0.1.10}/setup.cfg +1 -1
  49. {biolmai-0.1.8 → biolmai-0.1.10}/setup.py +4 -1
  50. biolmai-0.1.10/tests/test_biolmai.py +498 -0
  51. biolmai-0.1.8/biolmai/cls.py +0 -97
  52. biolmai-0.1.8/biolmai/validate.py +0 -134
  53. biolmai-0.1.8/biolmai.egg-info/SOURCES.txt +0 -64
  54. biolmai-0.1.8/docs/model-docs/ESM-1v.rst +0 -362
  55. biolmai-0.1.8/docs/model-docs/ESM2_Embeddings.rst +0 -242
  56. biolmai-0.1.8/docs/model-docs/ESMFold.rst +0 -252
  57. biolmai-0.1.8/docs/model-docs/ProteInfer_EC.rst +0 -249
  58. biolmai-0.1.8/docs/model-docs/ProteInfer_GO.rst +0 -329
  59. biolmai-0.1.8/docs/model-docs/progen2/ProGen2_BFD90.rst +0 -251
  60. biolmai-0.1.8/docs/model-docs/progen2/ProGen2_Medium.rst +0 -248
  61. biolmai-0.1.8/docs/python-client/get_started/authorization.rst +0 -9
  62. biolmai-0.1.8/docs/python-client/get_started/quickstart.rst +0 -15
  63. biolmai-0.1.8/tests/test_biolmai.py +0 -263
  64. {biolmai-0.1.8 → biolmai-0.1.10}/AUTHORS.rst +0 -0
  65. {biolmai-0.1.8 → biolmai-0.1.10}/CONTRIBUTING.rst +0 -0
  66. {biolmai-0.1.8 → biolmai-0.1.10}/HISTORY.rst +0 -0
  67. {biolmai-0.1.8 → biolmai-0.1.10}/LICENSE +0 -0
  68. {biolmai-0.1.8 → biolmai-0.1.10}/MANIFEST.in +0 -0
  69. {biolmai-0.1.8 → biolmai-0.1.10}/README.rst +0 -0
  70. {biolmai-0.1.8 → biolmai-0.1.10}/biolmai/auth.py +0 -0
  71. {biolmai-0.1.8 → biolmai-0.1.10}/biolmai/biolmai.py +0 -0
  72. {biolmai-0.1.8 → biolmai-0.1.10}/biolmai/cli.py +0 -0
  73. {biolmai-0.1.8 → biolmai-0.1.10}/biolmai/ltc.py +0 -0
  74. {biolmai-0.1.8 → biolmai-0.1.10}/biolmai.egg-info/dependency_links.txt +0 -0
  75. {biolmai-0.1.8 → biolmai-0.1.10}/biolmai.egg-info/entry_points.txt +0 -0
  76. {biolmai-0.1.8 → biolmai-0.1.10}/biolmai.egg-info/not-zip-safe +0 -0
  77. {biolmai-0.1.8 → biolmai-0.1.10}/biolmai.egg-info/top_level.txt +0 -0
  78. {biolmai-0.1.8 → biolmai-0.1.10}/docs/Makefile +0 -0
  79. {biolmai-0.1.8 → biolmai-0.1.10}/docs/_static/api_reference_icon.png +0 -0
  80. {biolmai-0.1.8 → biolmai-0.1.10}/docs/_static/biolm_docs_logo_dark.png +0 -0
  81. {biolmai-0.1.8 → biolmai-0.1.10}/docs/_static/biolm_docs_logo_light.png +0 -0
  82. {biolmai-0.1.8 → biolmai-0.1.10}/docs/_static/chat_agents_icon.png +0 -0
  83. {biolmai-0.1.8 → biolmai-0.1.10}/docs/_static/jupyter_notebooks_icon.png +0 -0
  84. {biolmai-0.1.8 → biolmai-0.1.10}/docs/_static/model_docs_icon.png +0 -0
  85. {biolmai-0.1.8 → biolmai-0.1.10}/docs/_static/python_sdk_icon.png +0 -0
  86. {biolmai-0.1.8 → biolmai-0.1.10}/docs/_static/tutorials_icon.png +0 -0
  87. {biolmai-0.1.8 → biolmai-0.1.10}/docs/biolmai.rst +0 -0
  88. {biolmai-0.1.8 → biolmai-0.1.10}/docs/make.bat +0 -0
  89. {biolmai-0.1.8 → biolmai-0.1.10}/docs/model-docs/img/book_icon.png +0 -0
  90. {biolmai-0.1.8 → biolmai-0.1.10}/docs/model-docs/img/esmfold_perf.png +0 -0
  91. {biolmai-0.1.8 → biolmai-0.1.10}/docs/model-docs/index.rst +0 -0
  92. {biolmai-0.1.8 → biolmai-0.1.10}/docs/model-docs/progen2/index.rst +0 -0
  93. {biolmai-0.1.8 → biolmai-0.1.10}/docs/modules.rst +0 -0
  94. {biolmai-0.1.8 → biolmai-0.1.10}/docs/python-client/usage.rst +0 -0
  95. {biolmai-0.1.8 → biolmai-0.1.10}/docs/tutorials_use_cases/notebooks.rst +0 -0
  96. {biolmai-0.1.8 → biolmai-0.1.10}/pyproject.toml +0 -0
  97. {biolmai-0.1.8 → biolmai-0.1.10}/tests/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: biolmai
3
- Version: 0.1.8
3
+ Version: 0.1.10
4
4
  Summary: Python client and SDK for https://biolm.ai
5
5
  Home-page: https://github.com/BioLM/py-biolm
6
6
  Author: Nikhil Haas
@@ -18,6 +18,7 @@ Classifier: Programming Language :: Python :: 3.9
18
18
  Classifier: Programming Language :: Python :: 3.10
19
19
  Classifier: Programming Language :: Python :: 3.11
20
20
  Requires-Python: >=3.6
21
+ Provides-Extra: aiodns
21
22
  License-File: LICENSE
22
23
  License-File: AUTHORS.rst
23
24
 
@@ -1,7 +1,8 @@
1
1
  """Top-level package for BioLM AI."""
2
2
  __author__ = """Nikhil Haas"""
3
3
  __email__ = "nikhil@biolm.ai"
4
- __version__ = '0.1.8'
4
+ __version__ = '0.1.10'
5
5
 
6
+ from biolmai.cls import *
6
7
 
7
8
  __all__ = []
@@ -15,7 +15,7 @@ import biolmai.auth
15
15
  from biolmai.asynch import async_api_call_wrapper
16
16
  from biolmai.biolmai import log
17
17
  from biolmai.const import MULTIPROCESS_THREADS
18
- from biolmai.payloads import INST_DAT_TXT, predict_resp_many_in_one_to_many_singles
18
+ from biolmai.payloads import INST_DAT_TXT, PARAMS_ITEMS, predict_resp_many_in_one_to_many_singles
19
19
 
20
20
 
21
21
  @lru_cache(maxsize=64)
@@ -35,65 +35,82 @@ def text_validator(text, c):
35
35
  except Exception as e:
36
36
  return str(e)
37
37
 
38
+ def combine_validation(x, y):
39
+ if x is None and y is None:
40
+ return None
41
+ elif isinstance(x, str) and y is None:
42
+ return x
43
+ elif x is None and isinstance(y, str):
44
+ return y
45
+ elif isinstance(x, str) and isinstance(y, str):
46
+ return f"{x}\n{y}"
47
+
48
+
49
+ def validate_action(action):
50
+ def validate(f):
51
+ def wrapper(*args, **kwargs):
52
+ # Get class instance at runtime, so you can access not just
53
+ # APIEndpoints, but any *parent* classes of that,
54
+ # like ESMFoldSinglechain.
55
+ class_obj_self = args[0]
56
+ try:
57
+ is_method = inspect.getfullargspec(f)[0][0] == "self"
58
+ except Exception:
59
+ is_method = False
38
60
 
39
- def validate(f):
40
- def wrapper(*args, **kwargs):
41
- # Get class instance at runtime, so you can access not just
42
- # APIEndpoints, but any *parent* classes of that,
43
- # like ESMFoldSinglechain.
44
- class_obj_self = args[0]
45
- try:
46
- is_method = inspect.getfullargspec(f)[0][0] == "self"
47
- except Exception:
48
- is_method = False
49
-
50
- # Is the function we decorated a class method?
51
- if is_method:
52
- name = f"{f.__module__}.{class_obj_self.__class__.__name__}.{f.__name__}"
53
- else:
54
- name = f"{f.__module__}.{f.__name__}"
55
-
56
- if is_method:
57
- # Splits name, e.g. 'biolmai.api.ESMFoldSingleChain.predict'
58
- action_method_name = name.split(".")[-1]
59
- validate_endpoint_action(
60
- class_obj_self.action_class_strings,
61
- action_method_name,
62
- class_obj_self.__class__.__name__,
63
- )
64
-
65
- input_data = args[1]
66
- # Validate each row's text/input based on class attribute `seq_classes`
67
- for c in class_obj_self.seq_classes:
68
- # Validate input data against regex
69
- if class_obj_self.multiprocess_threads:
70
- validation = input_data.text.apply(text_validator, args=(c,))
61
+ # Is the function we decorated a class method?
62
+ if is_method:
63
+ name = f"{f.__module__}.{class_obj_self.__class__.__name__}.{f.__name__}"
71
64
  else:
72
- validation = input_data.text.apply(text_validator, args=(c,))
73
- if "validation" not in input_data.columns:
74
- input_data["validation"] = validation
75
- else:
76
- input_data["validation"] = input_data["validation"].str.cat(
77
- validation, sep="\n", na_rep=""
65
+ name = f"{f.__module__}.{f.__name__}"
66
+
67
+ if is_method:
68
+ # Splits name, e.g. 'biolmai.api.ESMFoldSingleChain.predict'
69
+ action_method_name = name.split(".")[-1]
70
+ validate_endpoint_action(
71
+ class_obj_self.action_class_strings,
72
+ action_method_name,
73
+ class_obj_self.__class__.__name__,
78
74
  )
79
75
 
80
- # Mark your batches, excluding invalid rows
81
- valid_dat = input_data.loc[input_data.validation.isnull(), :].copy()
82
- N = class_obj_self.batch_size # N rows will go per API request
83
- # JOIN back, which is by index
84
- if valid_dat.shape[0] != input_data.shape[0]:
85
- valid_dat["batch"] = np.arange(valid_dat.shape[0]) // N
86
- input_data = input_data.merge(
87
- valid_dat.batch, left_index=True, right_index=True, how="left"
88
- )
89
- else:
90
- input_data["batch"] = np.arange(input_data.shape[0]) // N
91
-
92
- res = f(class_obj_self, input_data, **kwargs)
93
- return res
94
-
95
- return wrapper
76
+ input_data = args[1]
77
+ # Validate each row's text/input based on class attribute `seq_classes`
78
+ if action == "predict":
79
+ input_classes = class_obj_self.predict_input_classes
80
+ elif action == "encode":
81
+ input_classes = class_obj_self.encode_input_classes
82
+ elif action == "generate":
83
+ input_classes = class_obj_self.generate_input_classes
84
+ elif action == "transform":
85
+ input_classes = class_obj_self.transform_input_classes
86
+ for c in input_classes:
87
+ # Validate input data against regex
88
+ if class_obj_self.multiprocess_threads:
89
+ validation = input_data.text.apply(text_validator, args=(c,))
90
+ else:
91
+ validation = input_data.text.apply(text_validator, args=(c,))
92
+ if "validation" not in input_data.columns:
93
+ input_data["validation"] = validation
94
+ else:
95
+ # masking and loc may be more performant option
96
+ input_data["validation"] = input_data["validation"].combine(validation, combine_validation)
97
+
98
+ # Mark your batches, excluding invalid rows
99
+ valid_dat = input_data.loc[input_data.validation.isnull(), :].copy()
100
+ N = class_obj_self.batch_size # N rows will go per API request
101
+ # JOIN back, which is by index
102
+ if valid_dat.shape[0] != input_data.shape[0]:
103
+ valid_dat["batch"] = np.arange(valid_dat.shape[0]) // N
104
+ input_data = input_data.merge(
105
+ valid_dat.batch, left_index=True, right_index=True, how="left"
106
+ )
107
+ else:
108
+ input_data["batch"] = np.arange(input_data.shape[0]) // N
109
+ res = f(class_obj_self, input_data, **kwargs)
110
+ return res
96
111
 
112
+ return wrapper
113
+ return validate
97
114
 
98
115
  def convert_input(f):
99
116
  def wrapper(*args, **kwargs):
@@ -123,7 +140,20 @@ def convert_input(f):
123
140
 
124
141
 
125
142
  class APIEndpoint:
126
- batch_size = 3 # Overwrite in parent classes as needed
143
+ # Overwrite in parent classes as needed
144
+ batch_size = 3
145
+ params = None
146
+ action_classes = ()
147
+ api_version = 2
148
+
149
+ predict_input_key = "sequence"
150
+ encode_input_key = "sequence"
151
+ generate_input_key = "context"
152
+
153
+ predict_input_classes = ()
154
+ encode_input_classes = ()
155
+ generate_input_classes = ()
156
+ transform_input_classes = ()
127
157
 
128
158
  def __init__(self, multiprocess_threads=None):
129
159
  # Check for instance-specific threads, otherwise read from env var
@@ -137,7 +167,7 @@ class APIEndpoint:
137
167
  [c.__name__.replace("Action", "").lower() for c in self.action_classes]
138
168
  )
139
169
 
140
- def post_batches(self, dat, slug, action, payload_maker, resp_key):
170
+ def post_batches(self, dat, slug, action, payload_maker, resp_key, key="sequence", params=None):
141
171
  keep_batches = dat.loc[~dat.batch.isnull(), ["text", "batch"]]
142
172
  if keep_batches.shape[0] == 0:
143
173
  pass # Do nothing - we made nice JSON errors to return in the DF
@@ -145,7 +175,7 @@ class APIEndpoint:
145
175
  # raise AssertionError(err)
146
176
  if keep_batches.shape[0] > 0:
147
177
  api_resps = async_api_call_wrapper(
148
- keep_batches, slug, action, payload_maker, resp_key
178
+ keep_batches, slug, action, payload_maker, resp_key, api_version=self.api_version, key=key, params=params,
149
179
  )
150
180
  if isinstance(api_resps, pd.DataFrame):
151
181
  batch_res = api_resps.explode("api_resp") # Should be lists of results
@@ -170,11 +200,11 @@ class APIEndpoint:
170
200
  dat["api_resp"] = None
171
201
  return dat
172
202
 
173
- def unpack_local_validations(self, dat):
203
+ def unpack_local_validations(self, dat, response_key):
174
204
  dat.loc[dat.api_resp.isnull(), "api_resp"] = (
175
205
  dat.loc[~dat.validation.isnull(), "validation"]
176
206
  .apply(
177
- predict_resp_many_in_one_to_many_singles, args=(None, None, True, None)
207
+ predict_resp_many_in_one_to_many_singles, args=(None, None, True, None), response_key=response_key
178
208
  )
179
209
  .explode()
180
210
  )
@@ -182,39 +212,46 @@ class APIEndpoint:
182
212
  return dat
183
213
 
184
214
  @convert_input
185
- @validate
186
- def predict(self, dat):
187
- dat = self.post_batches(dat, self.slug, "predict", INST_DAT_TXT, "predictions")
188
- dat = self.unpack_local_validations(dat)
215
+ @validate_action("predict")
216
+ def predict(self, dat, params=None):
217
+ if self.api_version == 1:
218
+ dat = self.post_batches(dat, self.slug, "predict", INST_DAT_TXT, "predictions")
219
+ dat = self.unpack_local_validations(dat, "predictions")
220
+ else:
221
+ dat = self.post_batches(dat, self.slug, "predict", PARAMS_ITEMS, "results", key=self.predict_input_key, params=params)
222
+ dat = self.unpack_local_validations(dat,"results")
189
223
  return dat.api_resp.replace(np.nan, None).tolist()
190
224
 
191
- def infer(self, dat):
192
- return self.predict(dat)
225
+ def infer(self, dat, params=None):
226
+ return self.predict(dat, params)
193
227
 
194
228
  @convert_input
195
- @validate
229
+ @validate_action("transform") # api v1 legacy action
196
230
  def transform(self, dat):
197
231
  dat = self.post_batches(
198
232
  dat, self.slug, "transform", INST_DAT_TXT, "predictions"
199
233
  )
200
- dat = self.unpack_local_validations(dat)
234
+ dat = self.unpack_local_validations(dat,"predictions")
201
235
  return dat.api_resp.replace(np.nan, None).tolist()
202
236
 
203
- # @convert_input
204
- # @validate
205
- # def encode(self, dat):
206
- # # NOTE: we defined this for the specific case of ESM2
207
- # # TODO: this will be need again in v2 of API contract
208
- # dat = self.post_batches(dat, self.slug, "transform",
209
- # INST_DAT_TXT, "embeddings")
210
- # dat = self.unpack_local_validations(dat)
211
- # return dat.api_resp.replace(np.nan, None).tolist()
237
+ @convert_input
238
+ @validate_action("encode")
239
+ def encode(self, dat, params=None):
240
+
241
+ dat = self.post_batches(dat, self.slug, "encode", PARAMS_ITEMS, "results", key=self.encode_input_key, params=params)
242
+ dat = self.unpack_local_validations(dat, "results")
243
+ return dat.api_resp.replace(np.nan, None).tolist()
212
244
 
213
245
  @convert_input
214
- @validate
215
- def generate(self, dat):
216
- dat = self.post_batches(dat, self.slug, "generate", INST_DAT_TXT, "generated")
217
- dat = self.unpack_local_validations(dat)
246
+ @validate_action("generate")
247
+ def generate(self, dat, params=None):
248
+ if self.api_version == 1:
249
+ dat = self.post_batches(dat, self.slug, "generate", INST_DAT_TXT, "generated")
250
+ dat = self.unpack_local_validations(dat, "predictions")
251
+ else:
252
+ dat = self.post_batches(dat, self.slug, "generate", PARAMS_ITEMS, "results", key=self.generate_input_key, params=params)
253
+ dat = self.unpack_local_validations(dat, "results")
254
+
218
255
  return dat.api_resp.replace(np.nan, None).tolist()
219
256
 
220
257
 
@@ -290,9 +327,9 @@ class TransformAction:
290
327
  return "TransformAction"
291
328
 
292
329
 
293
- # class EncodeAction:
294
- # def __str__(self):
295
- # return "EncodeAction"
330
+ class EncodeAction:
331
+ def __str__(self):
332
+ return "EncodeAction"
296
333
 
297
334
 
298
335
  class ExplainAction:
@@ -7,7 +7,7 @@ import aiohttp.resolver
7
7
  from aiohttp import ClientSession
8
8
 
9
9
  from biolmai.auth import get_user_auth_header
10
- from biolmai.const import BASE_API_URL, MULTIPROCESS_THREADS
10
+ from biolmai.const import BASE_API_URL, BASE_API_URL_V1, MULTIPROCESS_THREADS
11
11
 
12
12
  aiohttp.resolver.DefaultResolver = aiohttp.resolver.AsyncResolver
13
13
 
@@ -146,11 +146,14 @@ async def async_main(urls, concurrency) -> list:
146
146
  return await get_all(urls, concurrency)
147
147
 
148
148
 
149
- async def async_api_calls(model_name, action, headers, payloads, response_key=None):
149
+ async def async_api_calls(model_name, action, headers, payloads, response_key=None, api_version=2):
150
150
  """Hit an arbitrary BioLM model inference API."""
151
151
  # Normally would POST multiple sequences at once for greater efficiency,
152
152
  # but for simplicity sake will do one at at time right now
153
- url = f"{BASE_API_URL}/models/{model_name}/{action}/"
153
+ if api_version == 1:
154
+ url = f"{BASE_API_URL_V1}/models/{model_name}/{action}/"
155
+ else:
156
+ url = f"{BASE_API_URL}/{model_name}/{action}/"
154
157
 
155
158
  if not isinstance(payloads, (list, dict)):
156
159
  err = "API request payload must be a list or dict, got {}"
@@ -180,15 +183,20 @@ async def async_api_calls(model_name, action, headers, payloads, response_key=No
180
183
  # return response
181
184
 
182
185
 
183
- def async_api_call_wrapper(grouped_df, slug, action, payload_maker, response_key):
186
+ def async_api_call_wrapper(grouped_df, slug, action, payload_maker, response_key, api_version=2, key="sequence", params=None):
184
187
  """Wrap API calls to assist with sequence validation as a pre-cursor to
185
188
  each API call.
186
189
  """
187
190
  model_name = slug
188
191
  # payload = payload_maker(grouped_df)
189
- init_ploads = grouped_df.groupby("batch").apply(
190
- payload_maker, include_batch_size=True
191
- )
192
+ if api_version == 1:
193
+ init_ploads = grouped_df.groupby("batch").apply(
194
+ payload_maker, include_batch_size=True
195
+ )
196
+ else:
197
+ init_ploads = grouped_df.groupby("batch").apply(
198
+ payload_maker, key=key, params=params, include_batch_size=True
199
+ )
192
200
  ploads = init_ploads.to_list()
193
201
  init_ploads = init_ploads.to_frame(name="pload")
194
202
  init_ploads["batch"] = init_ploads.index
@@ -208,7 +216,7 @@ def async_api_call_wrapper(grouped_df, slug, action, payload_maker, response_key
208
216
  # "https://python.org",
209
217
  # ]
210
218
  # concurrency = 3
211
- api_resp = run(async_api_calls(model_name, action, headers, ploads, response_key))
219
+ api_resp = run(async_api_calls(model_name, action, headers, ploads, response_key, api_version))
212
220
  api_resp = [item for sublist in api_resp for item in sublist]
213
221
  api_resp = sorted(api_resp, key=lambda x: x["batch_id"])
214
222
  # print(api_resp)
@@ -0,0 +1,176 @@
1
+ """API inference classes."""
2
+ from biolmai.api import APIEndpoint, GenerateAction, PredictAction, TransformAction, EncodeAction
3
+ from biolmai.validate import (AAExtended,
4
+ AAExtendedPlusExtra,
5
+ AAUnambiguous,
6
+ AAUnambiguousPlusExtra,
7
+ DNAUnambiguous,
8
+ SingleOrMoreOccurrencesOf,
9
+ SingleOccurrenceOf,
10
+ PDB,
11
+ AAUnambiguousEmpty
12
+ )
13
+
14
+
15
+ class ESMFoldSingleChain(APIEndpoint):
16
+ slug = "esmfold-singlechain"
17
+ action_classes = (PredictAction,)
18
+ predict_input_classes = (AAUnambiguous(),)
19
+ batch_size = 2
20
+
21
+
22
+ class ESMFoldMultiChain(APIEndpoint):
23
+ slug = "esmfold-multichain"
24
+ action_classes = (PredictAction,)
25
+ predict_input_classes = (AAExtendedPlusExtra(extra=[":"]),)
26
+ batch_size = 2
27
+
28
+
29
+ class ESM2(APIEndpoint):
30
+ """Example.
31
+
32
+ .. highlight:: python
33
+ .. code-block:: python
34
+
35
+ {
36
+ "items": [{
37
+ "sequence": "MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQ"
38
+ }]
39
+ }
40
+ """
41
+
42
+ action_classes = (EncodeAction, PredictAction, )
43
+ encode_input_classes = (AAUnambiguous(),)
44
+ predict_input_classes = (SingleOrMoreOccurrencesOf(token="<mask>"), AAExtendedPlusExtra(extra=["<mask>"]))
45
+ batch_size = 1
46
+
47
+ class ESM2_8M(ESM2):
48
+ slug = "esm2-8m"
49
+
50
+ class ESM2_35M(ESM2):
51
+ slug = "esm2-35m"
52
+
53
+ class ESM2_150M(ESM2):
54
+ slug = "esm2-150m"
55
+
56
+ class ESM2_650M(ESM2):
57
+ slug = "esm2-650m"
58
+
59
+ class ESM2_3B(ESM2):
60
+ slug = "esm2-3b"
61
+
62
+ class ESM1v(APIEndpoint):
63
+ """Example.
64
+
65
+ .. highlight:: python
66
+ .. code-block:: python
67
+
68
+ {
69
+ "items": [{
70
+ "sequence": "QERLEUTGR<mask>SLGYNIVAT"
71
+ }]
72
+ }
73
+ """
74
+ action_classes = (PredictAction,)
75
+ predict_input_classes = (SingleOccurrenceOf("<mask>"), AAExtendedPlusExtra(extra=["<mask>"]))
76
+ batch_size = 5
77
+
78
+ class ESM1v1(ESM1v):
79
+ slug = "esm1v-n1"
80
+
81
+ class ESM1v2(ESM1v):
82
+ slug = "esm1v-n2"
83
+
84
+ class ESM1v3(ESM1v):
85
+ slug = "esm1v-n3"
86
+
87
+ class ESM1v4(ESM1v):
88
+ slug = "esm1v-n4"
89
+
90
+ class ESM1v5(ESM1v):
91
+ slug = "esm1v-n5"
92
+
93
+ class ESM1vAll(ESM1v):
94
+ slug = "esm1v-all"
95
+
96
+ class ESMIF1(APIEndpoint):
97
+ slug = "esm-if1"
98
+ action_classes = (GenerateAction,)
99
+ generate_input_classes = PDB
100
+ batch_size = 2
101
+ generate_input_key = "pdb"
102
+
103
+
104
+ class ProGen2(APIEndpoint):
105
+ action_classes = (GenerateAction,)
106
+ generate_input_classes = (AAUnambiguousEmpty(),)
107
+ batch_size = 1
108
+
109
+ class ProGen2Oas(ProGen2):
110
+ slug = "progen2-oas"
111
+
112
+ class ProGen2Medium(ProGen2):
113
+ slug = "progen2-medium"
114
+
115
+ class ProGen2Large(ProGen2):
116
+ slug = "progen2-large"
117
+
118
+ class ProGen2BFD90(ProGen2):
119
+ slug = "progen2-bfd90"
120
+
121
+ class AbLang(APIEndpoint):
122
+ action_classes = (PredictAction, EncodeAction, GenerateAction,)
123
+ predict_input_classes = (AAUnambiguous(),)
124
+ encode_input_classes = (AAUnambiguous(),)
125
+ generate_input_classes = (SingleOrMoreOccurrencesOf(token="*"), AAUnambiguousPlusExtra(extra=["*"]))
126
+ batch_size = 32
127
+ generate_input_key = "sequence"
128
+
129
+ class AbLangHeavy(AbLang):
130
+ slug = "ablang-heavy"
131
+
132
+ class AbLangLight(AbLang):
133
+ slug = "ablang-light"
134
+
135
+ class DNABERT(APIEndpoint):
136
+ slug = "dnabert"
137
+ action_classes = (EncodeAction,)
138
+ encode_input_classes = (DNAUnambiguous(),)
139
+ batch_size = 10
140
+
141
+ class DNABERT2(APIEndpoint):
142
+ slug = "dnabert2"
143
+ action_classes = (EncodeAction,)
144
+ encode_input_classes = (DNAUnambiguous(),)
145
+ batch_size = 10
146
+
147
+ class BioLMToxV1(APIEndpoint):
148
+ """Example.
149
+
150
+ .. highlight:: python
151
+ .. code-block:: python
152
+
153
+ {
154
+ "instances": [{
155
+ "data": {"text": "MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQ"}
156
+ }]
157
+ }
158
+ """
159
+
160
+ slug = "biolmtox_v1"
161
+ action_classes = (TransformAction, PredictAction,)
162
+ predict_input_classes = (AAUnambiguous(),)
163
+ transform_input_classes = (AAUnambiguous(),)
164
+ batch_size = 1
165
+ api_version = 1
166
+
167
+ class ProteInfer(APIEndpoint):
168
+ action_classes = (PredictAction,)
169
+ predict_input_classes = (AAExtended(),)
170
+ batch_size = 64
171
+
172
+ class ProteInferEC(ProteInfer):
173
+ slug = "proteinfer-ec"
174
+
175
+ class ProteInferGO(ProteInfer):
176
+ slug = "proteinfer-go"
@@ -26,4 +26,5 @@ if int(MULTIPROCESS_THREADS) > max_threads or int(MULTIPROCESS_THREADS) > 128:
26
26
  elif int(MULTIPROCESS_THREADS) <= 0:
27
27
  err = "Environment variable BIOLMAI_THREADS must be a positive integer."
28
28
  raise ValueError(err)
29
- BASE_API_URL = f"{BASE_DOMAIN}/api/v1"
29
+ BASE_API_URL_V1 = f"{BASE_DOMAIN}/api/v1"
30
+ BASE_API_URL = f"{BASE_DOMAIN}/api/v2"
@@ -7,11 +7,22 @@ def INST_DAT_TXT(batch, include_batch_size=False):
7
7
  d["batch_size"] = len(d["instances"])
8
8
  return d
9
9
 
10
+ def PARAMS_ITEMS(batch, key="sequence", params=None, include_batch_size=False):
11
+ d = {"items": []}
12
+ for _, row in batch.iterrows():
13
+ inst = {key: row.text}
14
+ d["items"].append(inst)
15
+ if include_batch_size is True:
16
+ d["batch_size"] = len(d["items"])
17
+ if isinstance(params, dict):
18
+ d["params"] = params
19
+ return d
20
+
10
21
 
11
22
  def predict_resp_many_in_one_to_many_singles(
12
- resp_json, status_code, batch_id, local_err, batch_size
23
+ resp_json, status_code, batch_id, local_err, batch_size, response_key = "results"
13
24
  ):
14
- expected_root_key = "predictions"
25
+ expected_root_key = response_key
15
26
  to_ret = []
16
27
  if not local_err and status_code and status_code == 200:
17
28
  list_of_individual_seq_results = resp_json[expected_root_key]