biolmai 0.1.7__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of biolmai might be problematic.

biolmai/__init__.py CHANGED
@@ -1,7 +1,8 @@
 """Top-level package for BioLM AI."""
 __author__ = """Nikhil Haas"""
 __email__ = "nikhil@biolm.ai"
-__version__ = '0.1.7'
+__version__ = '0.1.9'
 
+from biolmai.cls import *
 
 __all__ = []
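
Since 0.1.9 re-exports everything from biolmai.cls at the package root, the model classes shown later in this diff can be imported directly from biolmai. A minimal usage sketch, assuming an API token is configured and that predict() still accepts a plain list of sequences (handled by the convert_input decorator):

    from biolmai import ESMFoldSingleChain  # re-exported via `from biolmai.cls import *`

    model = ESMFoldSingleChain()
    # Sequence borrowed from the ESM2 docstring example further down in this diff.
    results = model.predict(["MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQ"])
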
biolmai/api.py CHANGED
@@ -15,7 +15,7 @@ import biolmai.auth
 from biolmai.asynch import async_api_call_wrapper
 from biolmai.biolmai import log
 from biolmai.const import MULTIPROCESS_THREADS
-from biolmai.payloads import INST_DAT_TXT, predict_resp_many_in_one_to_many_singles
+from biolmai.payloads import INST_DAT_TXT, PARAMS_ITEMS, predict_resp_many_in_one_to_many_singles
 
 
 @lru_cache(maxsize=64)
@@ -35,65 +35,82 @@ def text_validator(text, c):
     except Exception as e:
         return str(e)
 
+def combine_validation(x, y):
+    if x is None and y is None:
+        return None
+    elif isinstance(x, str) and y is None:
+        return x
+    elif x is None and isinstance(y, str):
+        return y
+    elif isinstance(x, str) and isinstance(y, str):
+        return f"{x}\n{y}"
+
+
+def validate_action(action):
+    def validate(f):
+        def wrapper(*args, **kwargs):
+            # Get class instance at runtime, so you can access not just
+            # APIEndpoints, but any *parent* classes of that,
+            # like ESMFoldSinglechain.
+            class_obj_self = args[0]
+            try:
+                is_method = inspect.getfullargspec(f)[0][0] == "self"
+            except Exception:
+                is_method = False
 
-def validate(f):
-    def wrapper(*args, **kwargs):
-        # Get class instance at runtime, so you can access not just
-        # APIEndpoints, but any *parent* classes of that,
-        # like ESMFoldSinglechain.
-        class_obj_self = args[0]
-        try:
-            is_method = inspect.getfullargspec(f)[0][0] == "self"
-        except Exception:
-            is_method = False
-
-        # Is the function we decorated a class method?
-        if is_method:
-            name = f"{f.__module__}.{class_obj_self.__class__.__name__}.{f.__name__}"
-        else:
-            name = f"{f.__module__}.{f.__name__}"
-
-        if is_method:
-            # Splits name, e.g. 'biolmai.api.ESMFoldSingleChain.predict'
-            action_method_name = name.split(".")[-1]
-            validate_endpoint_action(
-                class_obj_self.action_class_strings,
-                action_method_name,
-                class_obj_self.__class__.__name__,
-            )
-
-        input_data = args[1]
-        # Validate each row's text/input based on class attribute `seq_classes`
-        for c in class_obj_self.seq_classes:
-            # Validate input data against regex
-            if class_obj_self.multiprocess_threads:
-                validation = input_data.text.apply(text_validator, args=(c,))
+            # Is the function we decorated a class method?
+            if is_method:
+                name = f"{f.__module__}.{class_obj_self.__class__.__name__}.{f.__name__}"
             else:
-                validation = input_data.text.apply(text_validator, args=(c,))
-            if "validation" not in input_data.columns:
-                input_data["validation"] = validation
-            else:
-                input_data["validation"] = input_data["validation"].str.cat(
-                    validation, sep="\n", na_rep=""
+                name = f"{f.__module__}.{f.__name__}"
+
+            if is_method:
+                # Splits name, e.g. 'biolmai.api.ESMFoldSingleChain.predict'
+                action_method_name = name.split(".")[-1]
+                validate_endpoint_action(
+                    class_obj_self.action_class_strings,
+                    action_method_name,
+                    class_obj_self.__class__.__name__,
                 )
 
-        # Mark your batches, excluding invalid rows
-        valid_dat = input_data.loc[input_data.validation.isnull(), :].copy()
-        N = class_obj_self.batch_size  # N rows will go per API request
-        # JOIN back, which is by index
-        if valid_dat.shape[0] != input_data.shape[0]:
-            valid_dat["batch"] = np.arange(valid_dat.shape[0]) // N
-            input_data = input_data.merge(
-                valid_dat.batch, left_index=True, right_index=True, how="left"
-            )
-        else:
-            input_data["batch"] = np.arange(input_data.shape[0]) // N
-
-        res = f(class_obj_self, input_data, **kwargs)
-        return res
-
-    return wrapper
+            input_data = args[1]
+            # Validate each row's text/input based on class attribute `seq_classes`
+            if action == "predict":
+                input_classes = class_obj_self.predict_input_classes
+            elif action == "encode":
+                input_classes = class_obj_self.encode_input_classes
+            elif action == "generate":
+                input_classes = class_obj_self.generate_input_classes
+            elif action == "transform":
+                input_classes = class_obj_self.transform_input_classes
+            for c in input_classes:
+                # Validate input data against regex
+                if class_obj_self.multiprocess_threads:
+                    validation = input_data.text.apply(text_validator, args=(c,))
+                else:
+                    validation = input_data.text.apply(text_validator, args=(c,))
+                if "validation" not in input_data.columns:
+                    input_data["validation"] = validation
+                else:
+                    # masking and loc may be more performant option
+                    input_data["validation"] = input_data["validation"].combine(validation, combine_validation)
+
+            # Mark your batches, excluding invalid rows
+            valid_dat = input_data.loc[input_data.validation.isnull(), :].copy()
+            N = class_obj_self.batch_size  # N rows will go per API request
+            # JOIN back, which is by index
+            if valid_dat.shape[0] != input_data.shape[0]:
+                valid_dat["batch"] = np.arange(valid_dat.shape[0]) // N
+                input_data = input_data.merge(
+                    valid_dat.batch, left_index=True, right_index=True, how="left"
+                )
+            else:
+                input_data["batch"] = np.arange(input_data.shape[0]) // N
+            res = f(class_obj_self, input_data, **kwargs)
+            return res
 
+        return wrapper
+    return validate
 
 def convert_input(f):
     def wrapper(*args, **kwargs):
@@ -123,7 +140,20 @@ def convert_input(f):
 
 
 class APIEndpoint:
-    batch_size = 3  # Overwrite in parent classes as needed
+    # Overwrite in parent classes as needed
+    batch_size = 3
+    params = None
+    action_classes = ()
+    api_version = 2
+
+    predict_input_key = "sequence"
+    encode_input_key = "sequence"
+    generate_input_key = "context"
+
+    predict_input_classes = ()
+    encode_input_classes = ()
+    generate_input_classes = ()
+    transform_input_classes = ()
 
     def __init__(self, multiprocess_threads=None):
        # Check for instance-specific threads, otherwise read from env var
@@ -137,7 +167,7 @@ class APIEndpoint:
            [c.__name__.replace("Action", "").lower() for c in self.action_classes]
        )
 
-    def post_batches(self, dat, slug, action, payload_maker, resp_key):
+    def post_batches(self, dat, slug, action, payload_maker, resp_key, key="sequence", params=None):
        keep_batches = dat.loc[~dat.batch.isnull(), ["text", "batch"]]
        if keep_batches.shape[0] == 0:
            pass  # Do nothing - we made nice JSON errors to return in the DF
@@ -145,7 +175,7 @@ class APIEndpoint:
            # raise AssertionError(err)
        if keep_batches.shape[0] > 0:
            api_resps = async_api_call_wrapper(
-                keep_batches, slug, action, payload_maker, resp_key
+                keep_batches, slug, action, payload_maker, resp_key, api_version=self.api_version, key=key, params=params,
            )
            if isinstance(api_resps, pd.DataFrame):
                batch_res = api_resps.explode("api_resp")  # Should be lists of results
@@ -170,11 +200,11 @@ class APIEndpoint:
        dat["api_resp"] = None
        return dat
 
-    def unpack_local_validations(self, dat):
+    def unpack_local_validations(self, dat, response_key):
        dat.loc[dat.api_resp.isnull(), "api_resp"] = (
            dat.loc[~dat.validation.isnull(), "validation"]
            .apply(
-                predict_resp_many_in_one_to_many_singles, args=(None, None, True, None)
+                predict_resp_many_in_one_to_many_singles, args=(None, None, True, None), response_key=response_key
            )
            .explode()
        )
@@ -182,39 +212,46 @@ class APIEndpoint:
        return dat
 
     @convert_input
-    @validate
-    def predict(self, dat):
-        dat = self.post_batches(dat, self.slug, "predict", INST_DAT_TXT, "predictions")
-        dat = self.unpack_local_validations(dat)
+    @validate_action("predict")
+    def predict(self, dat, params=None):
+        if self.api_version == 1:
+            dat = self.post_batches(dat, self.slug, "predict", INST_DAT_TXT, "predictions")
+            dat = self.unpack_local_validations(dat, "predictions")
+        else:
+            dat = self.post_batches(dat, self.slug, "predict", PARAMS_ITEMS, "results", key=self.predict_input_key, params=params)
+            dat = self.unpack_local_validations(dat,"results")
        return dat.api_resp.replace(np.nan, None).tolist()
 
-    def infer(self, dat):
-        return self.predict(dat)
+    def infer(self, dat, params=None):
+        return self.predict(dat, params)
 
     @convert_input
-    @validate
+    @validate_action("transform")  # api v1 legacy action
     def transform(self, dat):
        dat = self.post_batches(
            dat, self.slug, "transform", INST_DAT_TXT, "predictions"
        )
-        dat = self.unpack_local_validations(dat)
+        dat = self.unpack_local_validations(dat,"predictions")
        return dat.api_resp.replace(np.nan, None).tolist()
 
-    # @convert_input
-    # @validate
-    # def encode(self, dat):
-    #     # NOTE: we defined this for the specific case of ESM2
-    #     # TODO: this will be need again in v2 of API contract
-    #     dat = self.post_batches(dat, self.slug, "transform",
-    #                             INST_DAT_TXT, "embeddings")
-    #     dat = self.unpack_local_validations(dat)
-    #     return dat.api_resp.replace(np.nan, None).tolist()
+    @convert_input
+    @validate_action("encode")
+    def encode(self, dat, params=None):
+
+        dat = self.post_batches(dat, self.slug, "encode", PARAMS_ITEMS, "results", key=self.encode_input_key, params=params)
+        dat = self.unpack_local_validations(dat, "results")
+        return dat.api_resp.replace(np.nan, None).tolist()
 
     @convert_input
-    @validate
-    def generate(self, dat):
-        dat = self.post_batches(dat, self.slug, "generate", INST_DAT_TXT, "generated")
-        dat = self.unpack_local_validations(dat)
+    @validate_action("generate")
+    def generate(self, dat, params=None):
+        if self.api_version == 1:
+            dat = self.post_batches(dat, self.slug, "generate", INST_DAT_TXT, "generated")
+            dat = self.unpack_local_validations(dat, "predictions")
+        else:
+            dat = self.post_batches(dat, self.slug, "generate", PARAMS_ITEMS, "results", key=self.generate_input_key, params=params)
+            dat = self.unpack_local_validations(dat, "results")
+
        return dat.api_resp.replace(np.nan, None).tolist()
 
 
@@ -290,9 +327,9 @@ class TransformAction:
        return "TransformAction"
 
 
-# class EncodeAction:
-#     def __str__(self):
-#         return "EncodeAction"
+class EncodeAction:
+    def __str__(self):
+        return "EncodeAction"
 
 
 class ExplainAction:
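
The old module-level validate decorator is replaced by validate_action(action), which selects the validator tuple matching the method being called (predict_input_classes, encode_input_classes, generate_input_classes, or transform_input_classes), and per-row error strings from successive validators are now merged with pandas Series.combine instead of str.cat. A small self-contained sketch of that merge behaviour, using toy error strings rather than real validator output:

    import pandas as pd

    def combine_validation(x, y):
        # Mirrors biolmai.api.combine_validation: keep whichever error exists,
        # or join both when both validators produced one.
        if x is None and y is None:
            return None
        elif isinstance(x, str) and y is None:
            return x
        elif x is None and isinstance(y, str):
            return y
        elif isinstance(x, str) and isinstance(y, str):
            return f"{x}\n{y}"

    first = pd.Series([None, "bad residue 'J'"])     # errors from validator 1
    second = pd.Series(["missing '<mask>'", None])   # errors from validator 2
    print(first.combine(second, combine_validation).tolist())
    # ["missing '<mask>'", "bad residue 'J'"]
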
biolmai/asynch.py CHANGED
@@ -7,7 +7,7 @@ import aiohttp.resolver
 from aiohttp import ClientSession
 
 from biolmai.auth import get_user_auth_header
-from biolmai.const import BASE_API_URL, MULTIPROCESS_THREADS
+from biolmai.const import BASE_API_URL, BASE_API_URL_V1, MULTIPROCESS_THREADS
 aiohttp.resolver.DefaultResolver = aiohttp.resolver.AsyncResolver
 
 
@@ -146,11 +146,14 @@ async def async_main(urls, concurrency) -> list:
     return await get_all(urls, concurrency)
 
 
-async def async_api_calls(model_name, action, headers, payloads, response_key=None):
+async def async_api_calls(model_name, action, headers, payloads, response_key=None, api_version=2):
     """Hit an arbitrary BioLM model inference API."""
     # Normally would POST multiple sequences at once for greater efficiency,
     # but for simplicity sake will do one at at time right now
-    url = f"{BASE_API_URL}/models/{model_name}/{action}/"
+    if api_version == 1:
+        url = f"{BASE_API_URL_V1}/models/{model_name}/{action}/"
+    else:
+        url = f"{BASE_API_URL}/{model_name}/{action}/"
 
     if not isinstance(payloads, (list, dict)):
         err = "API request payload must be a list or dict, got {}"
@@ -180,15 +183,20 @@ async def async_api_calls(model_name, action, headers, payloads, response_key=No
     # return response
 
 
-def async_api_call_wrapper(grouped_df, slug, action, payload_maker, response_key):
+def async_api_call_wrapper(grouped_df, slug, action, payload_maker, response_key, api_version=2, key="sequence", params=None):
     """Wrap API calls to assist with sequence validation as a pre-cursor to
     each API call.
     """
     model_name = slug
     # payload = payload_maker(grouped_df)
-    init_ploads = grouped_df.groupby("batch").apply(
-        payload_maker, include_batch_size=True
-    )
+    if api_version == 1:
+        init_ploads = grouped_df.groupby("batch").apply(
+            payload_maker, include_batch_size=True
+        )
+    else:
+        init_ploads = grouped_df.groupby("batch").apply(
+            payload_maker, key=key, params=params, include_batch_size=True
+        )
     ploads = init_ploads.to_list()
     init_ploads = init_ploads.to_frame(name="pload")
     init_ploads["batch"] = init_ploads.index
@@ -208,7 +216,7 @@ def async_api_call_wrapper(grouped_df, slug, action, payload_maker, response_key
     #     "https://python.org",
     # ]
     # concurrency = 3
-    api_resp = run(async_api_calls(model_name, action, headers, ploads, response_key))
+    api_resp = run(async_api_calls(model_name, action, headers, ploads, response_key, api_version))
     api_resp = [item for sublist in api_resp for item in sublist]
     api_resp = sorted(api_resp, key=lambda x: x["batch_id"])
     # print(api_resp)
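
With api_version now threaded through async_api_call_wrapper and async_api_calls, the request URL depends on the endpoint's API version: version 1 keeps the old /api/v1/models/<slug>/<action>/ path, while the new default builds /api/v2/<slug>/<action>/. A rough illustration of the two shapes (the domain below is assumed for the example; the real value comes from BASE_DOMAIN in biolmai/const.py):

    BASE_DOMAIN = "https://biolm.ai"  # assumed here; biolmai/const.py supplies the real value
    BASE_API_URL_V1 = f"{BASE_DOMAIN}/api/v1"
    BASE_API_URL = f"{BASE_DOMAIN}/api/v2"

    def build_url(model_name, action, api_version=2):
        # Mirrors the branch added to async_api_calls().
        if api_version == 1:
            return f"{BASE_API_URL_V1}/models/{model_name}/{action}/"
        return f"{BASE_API_URL}/{model_name}/{action}/"

    print(build_url("esmfold-singlechain", "predict"))         # .../api/v2/esmfold-singlechain/predict/
    print(build_url("biolmtox_v1", "predict", api_version=1))  # .../api/v1/models/biolmtox_v1/predict/
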
biolmai/cls.py CHANGED
@@ -1,97 +1,176 @@
 """API inference classes."""
-from biolmai.api import APIEndpoint, GenerateAction, PredictAction, TransformAction
-from biolmai.validate import ExtendedAAPlusExtra, SingleOccurrenceOf, UnambiguousAA
+from biolmai.api import APIEndpoint, GenerateAction, PredictAction, TransformAction, EncodeAction
+from biolmai.validate import (AAExtended,
+                              AAExtendedPlusExtra,
+                              AAUnambiguous,
+                              AAUnambiguousPlusExtra,
+                              DNAUnambiguous,
+                              SingleOrMoreOccurrencesOf,
+                              SingleOccurrenceOf,
+                              PDB,
+                              AAUnambiguousEmpty
+                              )
 
 
 class ESMFoldSingleChain(APIEndpoint):
     slug = "esmfold-singlechain"
     action_classes = (PredictAction,)
-    seq_classes = (UnambiguousAA(),)
+    predict_input_classes = (AAUnambiguous(),)
     batch_size = 2
 
 
 class ESMFoldMultiChain(APIEndpoint):
     slug = "esmfold-multichain"
     action_classes = (PredictAction,)
-    seq_classes = (ExtendedAAPlusExtra(extra=[":"]),)
+    predict_input_classes = (AAExtendedPlusExtra(extra=[":"]),)
     batch_size = 2
 
 
-class ESM2Embeddings(APIEndpoint):
+class ESM2(APIEndpoint):
     """Example.
 
     .. highlight:: python
     .. code-block:: python
 
         {
-            "instances": [{
-                "data": {"text": "MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQ"}
+            "items": [{
+                "sequence": "MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQ"
             }]
         }
     """
 
-    slug = "esm2_t33_650M_UR50D"
-    action_classes = (TransformAction,)
-    seq_classes = (UnambiguousAA(),)
+    action_classes = (EncodeAction, PredictAction, )
+    encode_input_classes = (AAUnambiguous(),)
+    predict_input_classes = (SingleOrMoreOccurrencesOf(token="<mask>"), AAExtendedPlusExtra(extra=["<mask>"]))
     batch_size = 1
 
+class ESM2_8M(ESM2):
+    slug = "esm2-8m"
+
+class ESM2_35M(ESM2):
+    slug = "esm2-35m"
+
+class ESM2_150M(ESM2):
+    slug = "esm2-150m"
+
+class ESM2_650M(ESM2):
+    slug = "esm2-650m"
+
+class ESM2_3B(ESM2):
+    slug = "esm2-3b"
 
-class ESM1v1(APIEndpoint):
+class ESM1v(APIEndpoint):
     """Example.
 
     .. highlight:: python
     .. code-block:: python
 
         {
-            "instances": [{
-                "data": {"text": "QERLEUTGR<mask>SLGYNIVAT"}
+            "items": [{
+                "sequence": "QERLEUTGR<mask>SLGYNIVAT"
             }]
         }
     """
-
-    slug = "esm1v_t33_650M_UR90S_1"
     action_classes = (PredictAction,)
-    seq_classes = (SingleOccurrenceOf("<mask>"), ExtendedAAPlusExtra(extra=["<mask>"]))
+    predict_input_classes = (SingleOccurrenceOf("<mask>"), AAExtendedPlusExtra(extra=["<mask>"]))
     batch_size = 5
 
+class ESM1v1(ESM1v):
+    slug = "esm1v-n1"
 
-class ESM1v2(APIEndpoint):
-    slug = "esm1v_t33_650M_UR90S_2"
-    action_classes = (PredictAction,)
-    seq_classes = (SingleOccurrenceOf("<mask>"), ExtendedAAPlusExtra(extra=["<mask>"]))
-    batch_size = 5
+class ESM1v2(ESM1v):
+    slug = "esm1v-n2"
 
+class ESM1v3(ESM1v):
+    slug = "esm1v-n3"
 
-class ESM1v3(APIEndpoint):
-    slug = "esm1v_t33_650M_UR90S_3"
-    action_classes = (PredictAction,)
-    seq_classes = (SingleOccurrenceOf("<mask>"), ExtendedAAPlusExtra(extra=["<mask>"]))
-    batch_size = 5
+class ESM1v4(ESM1v):
+    slug = "esm1v-n4"
 
+class ESM1v5(ESM1v):
+    slug = "esm1v-n5"
 
-class ESM1v4(APIEndpoint):
-    slug = "esm1v_t33_650M_UR90S_4"
-    action_classes = (PredictAction,)
-    seq_classes = (SingleOccurrenceOf("<mask>"), ExtendedAAPlusExtra(extra=["<mask>"]))
-    batch_size = 5
-
-
-class ESM1v5(APIEndpoint):
-    slug = "esm1v_t33_650M_UR90S_5"
-    action_classes = (PredictAction,)
-    seq_classes = (SingleOccurrenceOf("<mask>"), ExtendedAAPlusExtra(extra=["<mask>"]))
-    batch_size = 5
-
+class ESM1vAll(ESM1v):
+    slug = "esm1v-all"
 
 class ESMIF1(APIEndpoint):
-    slug = "esmif1"
+    slug = "esm-if1"
     action_classes = (GenerateAction,)
-    seq_classes = ()
+    generate_input_classes = PDB
     batch_size = 2
+    generate_input_key = "pdb"
 
 
-class Progen2(APIEndpoint):
-    slug = "progen2"
+class ProGen2(APIEndpoint):
     action_classes = (GenerateAction,)
-    seq_classes = ()
+    generate_input_classes = (AAUnambiguousEmpty(),)
     batch_size = 1
+
+class ProGen2Oas(ProGen2):
+    slug = "progen2-oas"
+
+class ProGen2Medium(ProGen2):
+    slug = "progen2-medium"
+
+class ProGen2Large(ProGen2):
+    slug = "progen2-large"
+
+class ProGen2BFD90(ProGen2):
+    slug = "progen2-bfd90"
+
+class AbLang(APIEndpoint):
+    action_classes = (PredictAction, EncodeAction, GenerateAction,)
+    predict_input_classes = (AAUnambiguous(),)
+    encode_input_classes = (AAUnambiguous(),)
+    generate_input_classes = (SingleOrMoreOccurrencesOf(token="*"), AAUnambiguousPlusExtra(extra=["*"]))
+    batch_size = 32
+    generate_input_key = "sequence"
+
+class AbLangHeavy(AbLang):
+    slug = "ablang-heavy"
+
+class AbLangLight(AbLang):
+    slug = "ablang-light"
+
+class DNABERT(APIEndpoint):
+    slug = "dnabert"
+    action_classes = (EncodeAction,)
+    encode_input_classes = (DNAUnambiguous(),)
+    batch_size = 10
+
+class DNABERT2(APIEndpoint):
+    slug = "dnabert2"
+    action_classes = (EncodeAction,)
+    encode_input_classes = (DNAUnambiguous(),)
+    batch_size = 10
+
+class BioLMToxV1(APIEndpoint):
+    """Example.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        {
+            "instances": [{
+                "data": {"text": "MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQ"}
+            }]
+        }
+    """
+
+    slug = "biolmtox_v1"
+    action_classes = (TransformAction, PredictAction,)
+    predict_input_classes = (AAUnambiguous(),)
+    transform_input_classes = (AAUnambiguous(),)
+    batch_size = 1
+    api_version = 1
+
+class ProteInfer(APIEndpoint):
+    action_classes = (PredictAction,)
+    predict_input_classes = (AAExtended(),)
+    batch_size = 64
+
+class ProteInferEC(ProteInfer):
+    slug = "proteinfer-ec"
+
+class ProteInferGO(ProteInfer):
+    slug = "proteinfer-go"
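
The rewritten cls.py exposes one class per hosted model variant (the ESM2 and ESM1v size/seed families, ESM-IF1, ProGen2, AbLang, DNABERT, BioLMToxV1, ProteInfer), each carrying its slug, per-action validators, and batch size. A hedged usage sketch, assuming a configured API token and the list-style inputs handled by convert_input:

    from biolmai import ESM2_650M, ESM1v1

    # Embeddings from ESM-2 650M; encode() validates unambiguous residues only.
    esm2 = ESM2_650M()
    embeddings = esm2.encode(["MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQ"])

    # ESM-1v masked-token scoring; exactly one <mask> is required per sequence.
    esm1v = ESM1v1()
    scores = esm1v.predict(["QERLEUTGR<mask>SLGYNIVAT"])
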
biolmai/const.py CHANGED
@@ -26,4 +26,5 @@ if int(MULTIPROCESS_THREADS) > max_threads or int(MULTIPROCESS_THREADS) > 128:
 elif int(MULTIPROCESS_THREADS) <= 0:
     err = "Environment variable BIOLMAI_THREADS must be a positive integer."
     raise ValueError(err)
-BASE_API_URL = f"{BASE_DOMAIN}/api/v1"
+BASE_API_URL_V1 = f"{BASE_DOMAIN}/api/v1"
+BASE_API_URL = f"{BASE_DOMAIN}/api/v2"
biolmai/payloads.py CHANGED
@@ -7,11 +7,22 @@ def INST_DAT_TXT(batch, include_batch_size=False):
         d["batch_size"] = len(d["instances"])
     return d
 
+def PARAMS_ITEMS(batch, key="sequence", params=None, include_batch_size=False):
+    d = {"items": []}
+    for _, row in batch.iterrows():
+        inst = {key: row.text}
+        d["items"].append(inst)
+    if include_batch_size is True:
+        d["batch_size"] = len(d["items"])
+    if isinstance(params, dict):
+        d["params"] = params
+    return d
+
 
 
 def predict_resp_many_in_one_to_many_singles(
-    resp_json, status_code, batch_id, local_err, batch_size
+    resp_json, status_code, batch_id, local_err, batch_size, response_key = "results"
 ):
-    expected_root_key = "predictions"
+    expected_root_key = response_key
     to_ret = []
     if not local_err and status_code and status_code == 200:
         list_of_individual_seq_results = resp_json[expected_root_key]
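
PARAMS_ITEMS is the v2 counterpart of INST_DAT_TXT: instead of {"instances": [{"data": {"text": ...}}]} it builds an {"items": [...]} payload keyed by the endpoint's input key, with an optional params dict passed through. A quick sketch of the payload it produces for a toy batch (the params value here is purely hypothetical; real options depend on the endpoint):

    import pandas as pd
    from biolmai.payloads import PARAMS_ITEMS  # added in 0.1.9

    batch = pd.DataFrame({"text": ["MSILVTRP", "SPAGEELV"]})  # toy batch
    payload = PARAMS_ITEMS(batch, key="sequence",
                           params={"some_param": 123},  # hypothetical
                           include_batch_size=True)
    print(payload)
    # {'items': [{'sequence': 'MSILVTRP'}, {'sequence': 'SPAGEELV'}], 'batch_size': 2, 'params': {'some_param': 123}}
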
biolmai/validate.py CHANGED
@@ -1,4 +1,5 @@
 import re
+from typing import List
 
 UNAMBIGUOUS_AA = (
     "A",
@@ -22,113 +23,137 @@ UNAMBIGUOUS_AA = (
     "W",
     "Y",
 )
-AAs = "".join(UNAMBIGUOUS_AA)
-# Let's use extended list for ESM-1v
-AAs_EXTENDED = "ACDEFGHIKLMNPQRSTVWYBXZJUO"
+aa_unambiguous = "ACDEFGHIKLMNPQRSTVWY"
+aa_extended = aa_unambiguous + "BXZUO"
 
-
-UNAMBIGUOUS_DNA = ("A", "C", "T", "G")
-AMBIGUOUS_DNA = ("A", "C", "T", "G", "X", "N", "U")
+dna_unambiguous = "ACTG"
+dna_ambiguous = dna_unambiguous + "XNU"
 
 
 regexes = {
-    "empty_or_unambiguous_aa_validator": re.compile(f"^[{AAs}]*$"),
-    "empty_or_unambiguous_dna_validator": re.compile(r"^[ACGT]*$"),
-    "extended_aa_validator": re.compile(f"^[{AAs_EXTENDED}]+$"),
-    "unambiguous_aa_validator": re.compile(f"^[{AAs}]+$"),
-    "unambiguous_dna_validator": re.compile(r"^[ACGT]+$"),
+    "empty_or_aa_unambiguous_validator": re.compile(f"^[{aa_unambiguous}]*$"),
+    "aa_extended_validator": re.compile(f"^[{aa_extended}]+$"),
+    "aa_unambiguous_validator": re.compile(f"^[{aa_unambiguous}]+$"),
+    "empty_or_dna_unambiguous_validator": re.compile(f"^[{dna_unambiguous}]*$"),
+    "dna_unambiguous_validator": re.compile(f"^[{dna_unambiguous}]+$"),
 }
 
 
-def empty_or_unambiguous_aa_validator(txt):
-    r = regexes["empty_or_unambiguous_aa_validator"]
-    if not bool(r.match(txt)):
-        err = f"Residues can only be represented with '{AAs}' characters"
-        raise AssertionError(err)
-    return txt
+def empty_or_aa_unambiguous_validator(text: str) -> str:
+    if not regexes["empty_or_aa_unambiguous_validator"].match(text):
+        raise ValueError(
+            f"Residues can only be represented with '{aa_unambiguous}' characters"
+        )
+    return text
 
 
-def empty_or_unambiguous_dna_validator(txt):
-    r = regexes["empty_or_unambiguous_dna_validator"]
-    if not bool(r.match(txt)):
-        err = "Nucleotides can only be represented with 'ACTG' characters"
-        raise AssertionError(err)
-    return txt
+def empty_or_dna_unambiguous_validator(text: str) -> str:
+    if not regexes["empty_or_dna_unambiguous_validator"].match(text):
+        raise ValueError(
+            f"Nucleotides can only be represented with '{dna_unambiguous}' characters"
+        )
+    return text
 
 
-def extended_aa_validator(txt):
-    r = regexes["extended_aa_validator"]
-    if not bool(r.match(txt)):
-        err = (
-            f"Extended residues can only be represented with "
-            f"'{AAs_EXTENDED}' characters"
+def aa_extended_validator(text: str) -> str:
+    if not regexes["aa_extended_validator"].match(text):
+        raise ValueError(
+            f"Residues can only be represented with '{aa_extended}' characters"
         )
-        raise AssertionError(err)
-    return txt
+    return text
 
 
-def unambiguous_aa_validator(txt):
-    r = regexes["unambiguous_aa_validator"]
-    if not bool(r.match(txt)):
-        err = (
-            f"Unambiguous residues can only be represented with '{AAs}' " f"characters"
+def aa_unambiguous_validator(text: str) -> str:
+    if not regexes["aa_unambiguous_validator"].match(text):
+        raise ValueError(
+            f"Residues can only be represented with '{aa_unambiguous}' characters"
         )
-        raise AssertionError(err)
-    return txt
+    return text
 
 
-def unambiguous_dna_validator(txt):
-    r = regexes["unambiguous_dna_validator"]
-    if not bool(r.match(txt)):
-        err = (
-            "Unambiguous nucleotides can only be represented with 'ACTG' " "characters"
+def dna_unambiguous_validator(text: str) -> str:
+    if not regexes["dna_unambiguous_validator"].match(text):
+        raise ValueError(
+            f"Nucleotides can only be represented with '{dna_unambiguous}' characters"
        )
-        raise AssertionError(err)
-    return txt
+    return text
 
+def pdb_validator(text: str) -> str:
+    if "ATOM" not in text:
+        raise ValueError("PDB string does not appear to be a valid PDB")
+    return text
 
-class UnambiguousAA:
+
+class PDB:
+    def __call__(self, value):
+        _ = pdb_validator(value)
+class AAUnambiguous:
     def __call__(self, value):
-        _ = unambiguous_aa_validator(value)
+        _ = aa_unambiguous_validator(value)
 
+class AAExtended:
+    def __call__(self, value):
+        _ = aa_extended_validator(value)
 
-class UnambiguousAAPlusExtra:
-    def __init__(self, extra=None):
-        if extra is None:
-            extra = []
-        self.extra = extra
-        assert len(extra) > 0
-        assert isinstance(extra, list)
+class DNAUnambiguous:
+    def __call__(self, value):
+        _ = dna_unambiguous_validator(value)
 
+class AAUnambiguousEmpty:
     def __call__(self, value):
-        txt_clean = value
+        _ = empty_or_aa_unambiguous_validator(value)
+
+class AAUnambiguousPlusExtra:
+    def __init__(self, extra: List[str]):
+        if not extra:
+            raise ValueError("Extra cannot be empty")
+        self.extra = extra
+
+    def __call__(self, value: str) -> str:
+        text_clean = value
         for ex in self.extra:
-            txt_clean = value.replace(ex, "")
-        _ = unambiguous_aa_validator(txt_clean)
+            text_clean = text_clean.replace(ex, "")
+        aa_unambiguous_validator(text_clean)
+        return value
 
 
-class ExtendedAAPlusExtra:
-    def __init__(self, extra=None):
-        if extra is None:
-            extra = []
+class AAExtendedPlusExtra:
+    def __init__(self, extra: List[str]):
+        if not extra:
+            raise ValueError("Extra cannot be empty")
         self.extra = extra
-        assert len(extra) > 0
-        assert isinstance(extra, list)
 
-    def __call__(self, value):
-        txt_clean = value
+    def __call__(self, value: str) -> str:
+        text_clean = value
         for ex in self.extra:
-            txt_clean = value.replace(ex, "")
-        _ = extended_aa_validator(txt_clean)
+            text_clean = text_clean.replace(ex, "")
+        aa_extended_validator(text_clean)
+        return value
 
 
 class SingleOccurrenceOf:
-    def __init__(self, single_char):
-        self.single_char = single_char
+    def __init__(self, single_token: str):
+        self.single_token = single_token
+
+    def __call__(self, value: str) -> str:
+        count = value.count(self.single_token)
+        if count != 1:
+            raise ValueError(
+                f"Expected a single occurrence of '{self.single_token}', got {count}"
+            )
+        return value
+
+
+class SingleOrMoreOccurrencesOf:
+    def __init__(self, token: str):
+        self.token = token
+
+    def __call__(self, value: str) -> str:
+        count = value.count(self.token)
+        if count < 1:
+            raise ValueError(
+                f"Expected at least one occurrence of '{self.token}', got none"
+            )
+        return value
+
 
-    def __call__(self, value):
-        s = self.single_char
-        cc = value.count(s)
-        if cc != 1:
-            err = "Expected a single occurrence of '{}', got {}"
-            raise AssertionError(err.format(s, cc))
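
The validators are now small, typed callables that raise ValueError instead of AssertionError, so the per-row error strings captured by text_validator carry consistent messages regardless of which check fails. A brief sketch of how they behave on their own:

    from biolmai.validate import AAUnambiguous, SingleOccurrenceOf

    check_aa = AAUnambiguous()
    check_mask = SingleOccurrenceOf("<mask>")

    check_aa("MSILVTRP")              # passes silently
    check_mask("QERLEUTGR<mask>SLG")  # passes: exactly one <mask>

    try:
        check_aa("MSILVTRPJ1")  # 'J' and '1' are outside the unambiguous residue set
    except ValueError as e:
        print(e)  # Residues can only be represented with 'ACDEFGHIKLMNPQRSTVWY' characters
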
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: biolmai
-Version: 0.1.7
+Version: 0.1.9
 Summary: Python client and SDK for https://biolm.ai
 Home-page: https://github.com/BioLM/py-biolm
 Author: Nikhil Haas
@@ -0,0 +1,18 @@
+biolmai/__init__.py,sha256=05laq4xekEMZnrxknETvEsq9nY4Xa-CcZZs3ekK2aoA,162
+biolmai/api.py,sha256=1T38KUoOiPl8IjXfxsypIKGraLNcjtlDbtkrvohEZJU,12959
+biolmai/asynch.py,sha256=BVypJhhEEK2Bek2AhqNGn7FIRJehAbJflUdeeslbXFE,9073
+biolmai/auth.py,sha256=flI9KAD90qdXyLDnpJTrc9voKsiK0uWtD2ehsPBn8r4,6329
+biolmai/biolmai.py,sha256=xwjAvuw6AtmQdkRf_usSGUZ-k2oU-fjl82_WAgfSvVE,74
+biolmai/cli.py,sha256=bdb4q8QlN73A6Ttz0e-dBIwoct7PYqy5WSc52jCMIyU,1967
+biolmai/cls.py,sha256=Hiy_Qoj2Eb43oltnEUdJfMPCsOeFKZ-GUNljF-yShug,4287
+biolmai/const.py,sha256=vCSj-itsusZWoLR27DYQSpuq024GQz3-uKJuDUoPF0Y,1153
+biolmai/ltc.py,sha256=al7HZc5tLyUR5fmpIb95hOz5ctudVsc0xzjd_c2Ew3M,49
+biolmai/payloads.py,sha256=BOhEKl9kWkKMXy1YiNw2_eC6MJ4Dn6vKNvkhEBsM7Lw,1735
+biolmai/validate.py,sha256=58XMWrdWoDRmfiNAayWqrYaH3_bjRmEpG_yx6XSjTrM,4168
+biolmai-0.1.9.dist-info/AUTHORS.rst,sha256=TB_ACuFPgVmxn1NspYwksTdT6jdZeShcxfafmi-XWKQ,158
+biolmai-0.1.9.dist-info/LICENSE,sha256=8yt0SdP38I7a3g0zWqZjNe0VSDQhJA4bWLQSqqKtAVg,583
+biolmai-0.1.9.dist-info/METADATA,sha256=mEmPMicZdXQVKMsayll4CCaVdi2hWAxkpqL9ZbYqKKc,1929
+biolmai-0.1.9.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
+biolmai-0.1.9.dist-info/entry_points.txt,sha256=ylQnDpCYrxF1F9z_T7NRQcYMWYF5ia_KsTUuboxjEAM,44
+biolmai-0.1.9.dist-info/top_level.txt,sha256=jyQO45JN3g_jbdI8WqMnb0aEIzf4h1MrmPAZkKgfnwY,8
+biolmai-0.1.9.dist-info/RECORD,,
@@ -1,18 +0,0 @@
-biolmai/__init__.py,sha256=lJ7PiA_IyjKhz3dI8nrnqy8S_wqAHtEM3iN3v3eArr0,136
-biolmai/api.py,sha256=3DcXeTFwXdn2KpHrGPxFGN6bvzdFjK6_4KUZuaRe64w,10974
-biolmai/asynch.py,sha256=ZLCiNdGDR2XvijM6jFB2IFl3bG7ROp4PxKbo1rI5s7A,8698
-biolmai/auth.py,sha256=flI9KAD90qdXyLDnpJTrc9voKsiK0uWtD2ehsPBn8r4,6329
-biolmai/biolmai.py,sha256=xwjAvuw6AtmQdkRf_usSGUZ-k2oU-fjl82_WAgfSvVE,74
-biolmai/cli.py,sha256=bdb4q8QlN73A6Ttz0e-dBIwoct7PYqy5WSc52jCMIyU,1967
-biolmai/cls.py,sha256=yacZIwDyDq3sgU3FSc-l8uld83lkwSTh4wiS-vGNT4I,2425
-biolmai/const.py,sha256=kbpmBEm-bw7lhGIJcMFeq1pfsIYeRk01_JwBufjupXc,1111
-biolmai/ltc.py,sha256=al7HZc5tLyUR5fmpIb95hOz5ctudVsc0xzjd_c2Ew3M,49
-biolmai/payloads.py,sha256=WmFN9JUojbrdvd_By8WWURS6Gm5Bh1fPYK0UjLDCbzU,1356
-biolmai/validate.py,sha256=QdPDuZodHn85p1Y7KGkxCDMuRcXBOzAB9lkNZpigw9g,3311
-biolmai-0.1.7.dist-info/AUTHORS.rst,sha256=TB_ACuFPgVmxn1NspYwksTdT6jdZeShcxfafmi-XWKQ,158
-biolmai-0.1.7.dist-info/LICENSE,sha256=8yt0SdP38I7a3g0zWqZjNe0VSDQhJA4bWLQSqqKtAVg,583
-biolmai-0.1.7.dist-info/METADATA,sha256=S2JBm8gzzRm_Xsb0aY3LozcW9TSocbqFLZd8BsA7gQw,1929
-biolmai-0.1.7.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
-biolmai-0.1.7.dist-info/entry_points.txt,sha256=ylQnDpCYrxF1F9z_T7NRQcYMWYF5ia_KsTUuboxjEAM,44
-biolmai-0.1.7.dist-info/top_level.txt,sha256=jyQO45JN3g_jbdI8WqMnb0aEIzf4h1MrmPAZkKgfnwY,8
-biolmai-0.1.7.dist-info/RECORD,,