chembfn-webui 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chembfn-webui might be problematic. Click here for more details.

chembfn_webui/bin/app.py CHANGED
@@ -24,7 +24,7 @@ from bayesianflow_for_chem.data import (
24
24
  aa2vec,
25
25
  split_selfies,
26
26
  )
27
- from bayesianflow_for_chem.tool import sample, inpaint, adjust_lora_
27
+ from bayesianflow_for_chem.tool import sample, inpaint, adjust_lora_, quantise_model_
28
28
  from lib.utilities import (
29
29
  find_model,
30
30
  find_vocab,
@@ -41,6 +41,16 @@ cache_dir = Path(__file__).parent.parent / "cache"
41
41
 
42
42
 
43
43
  def selfies2vec(sel: str, vocab_dict: Dict[str, int]) -> List[int]:
44
+ """
45
+ Tokenise SELFIES string.
46
+
47
+ :param sel: SELFIES string
48
+ :param vocab_dict: vocabulary dictionary
49
+ :type sel: str
50
+ :type vocab_dict: dict
51
+ :return: a list of token indices
52
+ :rtype: list
53
+ """
44
54
  s = split_selfies(sel)
45
55
  unknown_id = None
46
56
  for key, idx in vocab_dict.items():
@@ -55,6 +65,23 @@ def refresh(
55
65
  ) -> Tuple[
56
66
  List[str], List[str], List[List[str]], List[List[str]], gr.Dropdown, gr.Dropdown
57
67
  ]:
68
+ """
69
+ Refresh model file list.
70
+
71
+ :param model_selected: the selected model name
72
+ :param vocab_selected: the selected vocabulary name
73
+ :param tokeniser_selected: the selected tokeniser name
74
+ :type model_selected: str
75
+ :type vocab_selected: str
76
+ :type tokeniser_selected: str
77
+ :return: a list of vocabulary names \n
78
+ a list of base model files \n
79
+ a list of standalone model files \n
80
+ a list of LoRA model files \n
81
+ Gradio Dropdown item \n
82
+ Gradio Dropdown item \n
83
+ :rtype: tuple
84
+ """
58
85
  global vocabs, models
59
86
  vocabs = find_vocab()
60
87
  models = find_model()
@@ -77,6 +104,16 @@ def refresh(
77
104
 
78
105
 
79
106
  def select_lora(evt: gr.SelectData, prompt: str) -> str:
107
+ """
108
+ Select LoRA model name from Dataframe object.
109
+
110
+ :param evt: `~gradio.SelectData` instance
111
+ :param prompt: prompt string
112
+ :type evt: gradio.SelectData
113
+ :type prompt: str
114
+ :return: new prompt string
115
+ :rtype: str
116
+ """
80
117
  global lora_selected
81
118
  if lora_selected: # avoid double select
82
119
  lora_selected = False
@@ -104,7 +141,49 @@ def run(
104
141
  scaffold: str,
105
142
  sar_control: str,
106
143
  exclude_token: str,
144
+ quantise: str,
145
+ jited: str,
107
146
  ) -> Tuple[List, List[str], str, str, str]:
147
+ """
148
+ Run generation or inpainting.
149
+
150
+ :param model_name: model name
151
+ :param token_name: tokeniser name
152
+ :param vocab_fn: customised vocabulary name
153
+ :param step: number of sampling steps
154
+ :param batch_size: batch-size
155
+ :param sequence_size: maximum sequence length
156
+ :param guidance_strength: guidance strength of conditioning
157
+ :param method: `"BFN"` or `"ODE"`
158
+ :param temperature: sampling temperature while ODE-solver used
159
+ :param prompt: prompt string
160
+ :param scaffold: molecular scaffold
161
+ :param sar_control: semi-autoregressive behaviour flags
162
+ :param exclude_token: unwanted tokens
163
+ :param quantise: `"on"` or `"off"`
164
+ :param jited: `"on"` or `"off"`
165
+ :type model_name: str
166
+ :type token_name: str
167
+ :type vocab_fn: str
168
+ :type step: int
169
+ :type batch_size: int
170
+ :type sequence_size: int
171
+ :type guidance_strength: float
172
+ :type method: str
173
+ :type temperature: float
174
+ :type prompt: str
175
+ :type scaffold: str
176
+ :type sar_control: str
177
+ :type exclude_token: str
178
+ :type quantise: str
179
+ :type jited: str
180
+ :return: list of images \n
181
+ list of generated molecules \n
182
+ Chemfig code \n
183
+ messages \n
184
+ cache file path
185
+ :rtype: tuple
186
+ """
108
187
  _message = []
109
188
  base_model_dict = dict(models["base"])
110
189
  standalone_model_dict = dict([[i[0], i[1]] for i in models["standalone"]])
@@ -113,6 +192,7 @@ def run(
113
192
  lora_label_dict = dict([[i[0], i[2] != []] for i in models["lora"]])
114
193
  standalone_lmax_dict = dict([[i[0], i[3]] for i in models["standalone"]])
115
194
  lora_lmax_dict = dict([[i[0], i[3]] for i in models["lora"]])
195
+ # ------- build tokeniser -------
116
196
  if token_name == "SMILES & SAFE":
117
197
  vocab_keys = VOCAB_KEYS
118
198
  tokeniser = smiles2vec
@@ -139,7 +219,7 @@ def run(
139
219
  # ------- build model -------
140
220
  prompt_info = parse_prompt(prompt)
141
221
  sar_flag = parse_sar_control(sar_control)
142
- print(prompt_info)
222
+ print("Prompt summary:", prompt_info) # show prompt info
143
223
  if not prompt_info["lora"]:
144
224
  if model_name in base_model_dict:
145
225
  lmax = sequence_size
@@ -166,6 +246,10 @@ def run(
166
246
  y = None
167
247
  _message.append(f"Sequence length is set to {lmax} from model metadata.")
168
248
  bfn.semi_autoregressive = sar_flag[0]
249
+ if quantise == "on":
250
+ quantise_model_(bfn)
251
+ if jited == "on":
252
+ bfn.compile()
169
253
  elif len(prompt_info["lora"]) == 1:
170
254
  lmax = lora_lmax_dict[prompt_info["lora"][0]]
171
255
  if model_name in base_model_dict:
@@ -194,6 +278,10 @@ def run(
194
278
  adjust_lora_(bfn, prompt_info["lora_scaling"][0])
195
279
  _message.append(f"Sequence length is set to {lmax} from model metadata.")
196
280
  bfn.semi_autoregressive = sar_flag[0]
281
+ if quantise == "on":
282
+ quantise_model_(bfn)
283
+ if jited == "on":
284
+ bfn.compile()
197
285
  else:
198
286
  lmax = max([lora_lmax_dict[i] for i in prompt_info["lora"]])
199
287
  if model_name in base_model_dict:
@@ -211,6 +299,10 @@ def run(
211
299
  sar_flag = [sar_flag[0] for _ in range(len(weights))]
212
300
  bfn = EnsembleChemBFN(base_model_dir, lora_dir, mlps, weights)
213
301
  y = [torch.tensor([i], dtype=torch.float32) for i in prompt_info["objective"]]
302
+ if quantise == "on":
303
+ bfn.quantise()
304
+ if jited == "on":
305
+ bfn.compile()
214
306
  _message.append(f"Sequence length is set to {lmax} from model metadata.")
215
307
  # ------- inference -------
216
308
  allowed_tokens = parse_exclude_token(exclude_token, vocab_keys)
@@ -366,6 +458,8 @@ with gr.Blocks(title="ChemBFN WebUI") as app:
366
458
  label="exclude tokens",
367
459
  placeholder="key in unwanted tokens separated by comma.",
368
460
  )
461
+ quantise = gr.Radio(["on", "off"], value="off", label="quantisation")
462
+ jited = gr.Radio(["on", "off"], value="off", label="JIT")
369
463
  # ------ user interaction events -------
370
464
  btn.click(
371
465
  fn=run,
@@ -383,6 +477,8 @@ with gr.Blocks(title="ChemBFN WebUI") as app:
383
477
  scaffold,
384
478
  sar_control,
385
479
  exclude_token,
480
+ quantise,
481
+ jited,
386
482
  ],
387
483
  outputs=[img, result, chemfig, message, btn_download],
388
484
  )
@@ -426,6 +522,12 @@ with gr.Blocks(title="ChemBFN WebUI") as app:
426
522
 
427
523
 
428
524
  def main() -> None:
525
+ """
526
+ Main function.
527
+
528
+ :return:
529
+ :rtype: None
530
+ """
429
531
  parser = argparse.ArgumentParser()
430
532
  parser.add_argument(
431
533
  "--public", default=False, help="open to public", action="store_true"
@@ -1,15 +1,7 @@
1
- [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][O]
2
- [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][=Branch1][C][=O][N][C][C][C]
3
- [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][N][C][C][Branch1][=Branch2][C][=C][C][=C][C][=C][Ring1][=Branch1][C][=C][C][=C][C][=C][Ring1][=Branch1]
4
- [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][O]
5
- [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
6
- [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
7
- [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
8
- [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
9
- [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
10
- [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
11
- [O][=C][Branch2][Ring1][C][N][N][C][=Branch1][C][=O][C][=C][C][=C][Branch1][C][Cl][C][=C][Ring1][#Branch1][C][=C][C][=C][C][=C][Ring1][=Branch1]
12
- [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
13
- [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][=Branch1][C][=O][O]
14
- [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
15
- [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
1
+ CC(C)(CC#N)C#N
2
+ CCC(C)(C#N)C#N
3
+ C1C(C)(C#N)C1C#N
4
+ C(C)(C)C(C#N)C#N
5
+ C(C)CC(C#N)C#N
6
+ CCC(C)(C#N)C#N
7
+ CC(C)(C#N)C#N
@@ -15,6 +15,12 @@ if "CHEMBFN_WEBUI_MODEL_DIR" in os.environ:
15
15
 
16
16
 
17
17
  def find_vocab() -> Dict[str, str]:
18
+ """
19
+ Find customised vocabulary files.
20
+
21
+ :return: {file_name: file_path}
22
+ :rtype: dict
23
+ """
18
24
  vocab_fns = glob(str(_model_path / "vocab/*.txt"))
19
25
  return {
20
26
  os.path.basename(i).replace(".txt", ""): i
@@ -24,6 +30,17 @@ def find_vocab() -> Dict[str, str]:
24
30
 
25
31
 
26
32
  def find_model() -> Dict[str, List[List[Union[str, int, List[str], Path]]]]:
33
+ """
34
+ Find model files.
35
+
36
+ :return: ```
37
+ {
38
+ "base": [[name1, path1], [name2, path2], ...],
39
+ "standalone": [[name1, parent_path1, label1, pad_len1], ...],
40
+ "lora": [[name1, parent_path1, label1, pad_len1], ...]
41
+ }```
42
+ :rtype: dict
43
+ """
27
44
  models = {}
28
45
  # find base models
29
46
  base_fns = glob(str(_model_path / "base_model/*.pt"))
@@ -61,7 +78,21 @@ def find_model() -> Dict[str, List[List[Union[str, int, List[str], Path]]]]:
61
78
  return models
62
79
 
63
80
 
64
- def _get_lora_info(prompt: str) -> Tuple[str, List[float], List[float]]:
81
+ def _get_lora_info(prompt: str) -> Tuple[str, List[float], float]:
82
+ """
83
+ Parse sub-prompt string containing LoRA info.
84
+
85
+ :param prompt: LoRA sub-prompt: \n
86
+ case I. `"<name:A>"` \n
87
+ case II. `"<name>"` \n
88
+ case III. `"<name:A>:[a,b,...]"` \n
89
+ case IV. `"<name>:[a,b,c,...]"`
90
+ :type prompt: str
91
+ :return: LoRA name \n
92
+ objective values \n
93
+ LoRA scaling
94
+ :rtype: tuple
95
+ """
65
96
  s = prompt.split(">")
66
97
  s1 = s[0].replace("<", "")
67
98
  lora_info = s1.split(":")
@@ -83,6 +114,27 @@ def _get_lora_info(prompt: str) -> Tuple[str, List[float], List[float]]:
83
114
  def parse_prompt(
84
115
  prompt: str,
85
116
  ) -> Dict[str, Union[List[str], List[float], List[List[float]]]]:
117
+ """
118
+ Parse prompt.
119
+
120
+ :param prompt: prompt string: \n
121
+ case I. empty string `""` --> `{"lora": [], "objective": [], "lora_scaling": []}`\n
122
+ case II. one condition `"[a,b,c,...]"` --> `{"lora": [], "objective": [[a, b, c, ...]], "lora_scaling": []}`\n
123
+ case III. one LoRA `"<name:A>"` --> `{"lora": [name], "objective": [], "lora_scaling": [A]}`\n
124
+ case IV. one LoRA `"<name>"` --> `{"lora": [name], "objective": [], "lora_scaling": [1]}`\n
125
+ case V. one LoRA with condition `"<name:A>:[a,b,...]"` --> `{"lora": [name], "objective": [[a, b, ...]], "lora_scaling": [A]}`\n
126
+ case VI. one LoRA with condition `"<name>:[a,b,...]"` --> `{"lora": [name], "objective": [[a, b, ...]], "lora_scaling": [1]}`\n
127
+ case VII. several LoRAs with conditions `"<name1:A1>:[a1,b1,...];<name2>:[a2,b2,c2,...]"` --> `{"lora": [name1, name2], "objective": [[a1, b1, ...], [a2, b2, c2, ...]], "lora_scaling": [A1, 1]}`\n
128
+ case VIII. other cases --> `{"lora": [], "objective": [], "lora_scaling": []}`\n
129
+ :type prompt: str
130
+ :return: ```
131
+ {
132
+ "lora": [name1, name2, ...],
133
+ "objective": [obj1, obj2, ...],
134
+ "lora_scaling": [s1, s2, ...]
135
+ }```
136
+ :rtype: dict
137
+ """
86
138
  prompt_group = prompt.strip().replace("\n", "").split(";")
87
139
  prompt_group = [i for i in prompt_group if i]
88
140
  info = {"lora": [], "objective": [], "lora_scaling": []}
@@ -114,6 +166,16 @@ def parse_prompt(
114
166
 
115
167
 
116
168
  def parse_exclude_token(tokens: str, vocab_keys: List[str]) -> List[str]:
169
+ """
170
+ Parse exclude token string.
171
+
172
+ :param tokens: unwanted token string in the format `"token1,token2,..."`
173
+ :param vocab_keys: vocabulary elements
174
+ :type tokens: str
175
+ :type vocab_keys: list
176
+ :return: a list of allowed vocabulary
177
+ :rtype: list
178
+ """
117
179
  tokens = tokens.strip().replace("\n", "").split(",")
118
180
  tokens = [i for i in tokens if i]
119
181
  if not tokens:
@@ -121,7 +183,21 @@ def parse_exclude_token(tokens: str, vocab_keys: List[str]) -> List[str]:
121
183
  tokens = [i for i in vocab_keys if i not in tokens]
122
184
  return tokens
123
185
 
186
+
124
187
  def parse_sar_control(sar_control: str) -> List[bool]:
188
+ """
189
+ Parse semi-autoregression control string.
190
+
191
+ :param sar_control: semi-autoregression control string: \n
192
+ case I. `""` --> `[False]` \n
193
+ case II. `"F"` --> `[False]` \n
194
+ case III. `"T"` --> `[True]` \n
195
+ case IV. `F,T,...` --> `[False, True, ...]` \n
196
+ case V. other cases --> `[False, False, ...]` \n
197
+ :type sar_control: str
198
+ :return: a list of SAR flag
199
+ :rtype: list
200
+ """
125
201
  sar_flag = sar_control.strip().replace("\n", "").split(",")
126
202
  sar_flag = [i for i in sar_flag if i]
127
203
  if not sar_flag:
@@ -4,5 +4,5 @@
4
4
  Version info.
5
5
  """
6
6
 
7
- __version__ = "0.1.0"
7
+ __version__ = "0.2.0"
8
8
  __author__ = "Nianze A. TAO"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: chembfn_webui
3
- Version: 0.1.0
3
+ Version: 0.2.0
4
4
  Summary: WebUI for ChemBFN
5
5
  Home-page: https://github.com/Augus1999/ChemBFN-WebUI
6
6
  Author: Nianze A. Tao
@@ -122,9 +122,16 @@ $ chembfn
122
122
 
123
123
  Under "advanced control" tab
124
124
 
125
- * You can control semi-autoregressive behaviours by key in `F` for switch off SAR, `T` for switch on SAR, and prompt like `F,F,T,...` to individually control the SAR in an ensemble model.
125
+ * You can control semi-autoregressive behaviours by key in `F` for switching off SAR, `T` for switching on SAR, and prompt like `F,F,T,...` to individually control the SAR in an ensemble model.
126
126
  * You can add unwanted tokens, e.g., `[Cu],p,[Si]`.
127
127
 
128
128
  ### 6. Generate molecules
129
129
 
130
130
  Click "RUN" then here you go! If an error occurred, please check your prompts and settings.
131
+
132
+ ## Where to obtain the models?
133
+
134
+ * Pretrained models: [https://huggingface.co/suenoomozawa/ChemBFN](https://huggingface.co/suenoomozawa/ChemBFN)
135
+ * ChemBFN source code: [https://github.com/Augus1999/bayesian-flow-network-for-chemistry](https://github.com/Augus1999/bayesian-flow-network-for-chemistry)
136
+ * ChemBFN document: [https://augus1999.github.io/bayesian-flow-network-for-chemistry/](https://augus1999.github.io/bayesian-flow-network-for-chemistry/)
137
+ * ChemBFN package: [https://pypi.org/project/bayesianflow-for-chem/](https://pypi.org/project/bayesianflow-for-chem/)
@@ -1,16 +1,16 @@
1
1
  chembfn_webui/__init__.py,sha256=AXUdd_PrlfVO56losFUP7A8XrqCDPylwRbTpe_WG3Uc,87
2
- chembfn_webui/bin/app.py,sha256=GLXsqaZFmKu3dj35Ja-ygPUQSLK-uKgVIMxZQipXf5c,15809
2
+ chembfn_webui/bin/app.py,sha256=nD6M_e3v7aI6Iyfr3ntFXkpCC24LNeU0XaK-bT5EveA,18864
3
3
  chembfn_webui/cache/cache_file_here.txt,sha256=hi60T_q6Cf5WPtXuwe4CqjiWpaUqrczsmGMhKIUL--M,28
4
- chembfn_webui/cache/results.csv,sha256=cNmpygApXW6XLwkZfKkLRh6BwlwURkHZ17da8qUDjac,1670
5
- chembfn_webui/lib/utilities.py,sha256=bnAAhfryDpZpAMk5p0eURJ2nhgaXgTY5QWXITdL26gc,4476
6
- chembfn_webui/lib/version.py,sha256=3uax1uzsS9zcwmKGqogR9oHyvdv4l5UktCj3R9mW1p4,138
4
+ chembfn_webui/cache/results.csv,sha256=QDwo2y-HHfxbvsNY4Tp8jpLOXOwLhzapJIRaxwQ4BS0,107
5
+ chembfn_webui/lib/utilities.py,sha256=ALPw-Evjd9DdsU_RQA6Zp2Gc6XnRR7Y_5fZrqG9azWo,7460
6
+ chembfn_webui/lib/version.py,sha256=tOCr0-h9d8eZdkQ040lxB9yzvb9spVCyxqjIs-Tt5yc,138
7
7
  chembfn_webui/model/base_model/place_base_model_here.txt,sha256=oa8_ILaAlWpTXICVDi-Y46_OahV7wB6Che6gbiEIh-c,39
8
8
  chembfn_webui/model/lora/place_lora_folder_here.txt,sha256=YYOo0Cj278DyRcgVrCLa1f2Q-cqgNeMnelaLiA3Fuic,69
9
9
  chembfn_webui/model/standalone_model/place_standalone_model_folder_here.txt,sha256=Dp42UscfI0Zp3SnvRv5vOfWiJZnxdY7rG3jo0kf86VM,80
10
10
  chembfn_webui/model/vocab/place_vocabulary_file_here.txt,sha256=fLOINvZP2022oE7RsmfDjgyaw2yMi7glmdu_cTwmo88,28
11
- chembfn_webui-0.1.0.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
12
- chembfn_webui-0.1.0.dist-info/METADATA,sha256=r9Obs3CWZy_ZK42c46gDXMAORUWQhAv4WhL_mpdEO4o,5125
13
- chembfn_webui-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
14
- chembfn_webui-0.1.0.dist-info/entry_points.txt,sha256=fp8WTPybvwpeYKrUhTi456wwZbmCMJXN1TeFGpR1SlY,55
15
- chembfn_webui-0.1.0.dist-info/top_level.txt,sha256=VdWt3Z7jhbB0pQO_mkRawnU5s75SBT9BV8fGaAIJTDI,14
16
- chembfn_webui-0.1.0.dist-info/RECORD,,
11
+ chembfn_webui-0.2.0.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
12
+ chembfn_webui-0.2.0.dist-info/METADATA,sha256=qUKPuLkPeeq2zsGRaqVE_LEbVNtW7VLONUp9nHaLBM4,5710
13
+ chembfn_webui-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
14
+ chembfn_webui-0.2.0.dist-info/entry_points.txt,sha256=fp8WTPybvwpeYKrUhTi456wwZbmCMJXN1TeFGpR1SlY,55
15
+ chembfn_webui-0.2.0.dist-info/top_level.txt,sha256=VdWt3Z7jhbB0pQO_mkRawnU5s75SBT9BV8fGaAIJTDI,14
16
+ chembfn_webui-0.2.0.dist-info/RECORD,,