chembfn_webui-0.1.0-py3-none-any.whl
This diff shows the content of a publicly available package version as released to a supported registry. It is provided for informational purposes only and reflects the package as it appears in its public registry.
- chembfn_webui/__init__.py +5 -0
- chembfn_webui/bin/app.py +439 -0
- chembfn_webui/cache/cache_file_here.txt +1 -0
- chembfn_webui/cache/results.csv +15 -0
- chembfn_webui/lib/utilities.py +134 -0
- chembfn_webui/lib/version.py +8 -0
- chembfn_webui/model/base_model/place_base_model_here.txt +1 -0
- chembfn_webui/model/lora/place_lora_folder_here.txt +1 -0
- chembfn_webui/model/standalone_model/place_standalone_model_folder_here.txt +1 -0
- chembfn_webui/model/vocab/place_vocabulary_file_here.txt +1 -0
- chembfn_webui-0.1.0.dist-info/METADATA +130 -0
- chembfn_webui-0.1.0.dist-info/RECORD +16 -0
- chembfn_webui-0.1.0.dist-info/WHEEL +5 -0
- chembfn_webui-0.1.0.dist-info/entry_points.txt +2 -0
- chembfn_webui-0.1.0.dist-info/licenses/LICENSE +661 -0
- chembfn_webui-0.1.0.dist-info/top_level.txt +1 -0
chembfn_webui/bin/app.py
ADDED
@@ -0,0 +1,439 @@
# -*- coding: utf-8 -*-
# Author: Nianze A. TAO (omozawa SUENO)
"""
Define application behaviours.
"""
import sys
import argparse
from pathlib import Path
from functools import partial
from typing import Tuple, List, Dict

sys.path.append(str(Path(__file__).parent.parent))
from rdkit.Chem import Draw, MolFromSmiles, MolFromFASTA
from mol2chemfigPy3 import mol2chemfig
import gradio as gr
import torch
from selfies import decoder
from bayesianflow_for_chem import ChemBFN, MLP, EnsembleChemBFN
from bayesianflow_for_chem.data import (
    VOCAB_KEYS,
    AA_VOCAB_KEYS,
    load_vocab,
    smiles2vec,
    aa2vec,
    split_selfies,
)
from bayesianflow_for_chem.tool import sample, inpaint, adjust_lora_
from lib.utilities import (
    find_model,
    find_vocab,
    parse_prompt,
    parse_exclude_token,
    parse_sar_control,
)
from lib.version import __version__

vocabs = find_vocab()
models = find_model()
lora_selected = False  # LoRA select flag
cache_dir = Path(__file__).parent.parent / "cache"


def selfies2vec(sel: str, vocab_dict: Dict[str, int]) -> List[int]:
    s = split_selfies(sel)
    unknown_id = None
    for key, idx in vocab_dict.items():
        if "unknown" in key.lower():
            unknown_id = idx
            break
    # dict.get takes no `default=` keyword; pass the fallback positionally
    return [vocab_dict.get(i, unknown_id) for i in s]


def refresh(
    model_selected: str, vocab_selected: str, tokeniser_selected: str
) -> Tuple[
    List[str], List[str], List[List[str]], List[List[str]], gr.Dropdown, gr.Dropdown
]:
    global vocabs, models
    vocabs = find_vocab()
    models = find_model()
    a = list(vocabs.keys())
    b = [i[0] for i in models["base"]]
    c = [[i[0], i[2]] for i in models["standalone"]]
    d = [[i[0], i[2]] for i in models["lora"]]
    e = gr.Dropdown(
        [i[0] for i in models["base"]] + [i[0] for i in models["standalone"]],
        value=model_selected,
        label="model",
    )
    f = gr.Dropdown(
        list(vocabs.keys()),
        value=vocab_selected,
        label="vocabulary",
        visible=tokeniser_selected == "SELFIES",
    )
    return a, b, c, d, e, f


def select_lora(evt: gr.SelectData, prompt: str) -> str:
    global lora_selected
    if lora_selected:  # avoid double select
        lora_selected = False
        return prompt
    selected_lora = evt.value
    lora_selected = True
    if evt.index[1] != 0:
        return prompt
    if not prompt:
        return f"<{selected_lora}:1>"
    return f"{prompt};\n<{selected_lora}:1>"


def run(
    model_name: str,
    token_name: str,
    vocab_fn: str,
    step: int,
    batch_size: int,
    sequence_size: int,
    guidance_strength: float,
    method: str,
    temperature: float,
    prompt: str,
    scaffold: str,
    sar_control: str,
    exclude_token: str,
) -> Tuple[List, List[str], str, str, str]:
    _message = []
    base_model_dict = dict(models["base"])
    standalone_model_dict = dict([[i[0], i[1]] for i in models["standalone"]])
    lora_model_dict = dict([[i[0], i[1]] for i in models["lora"]])
    standalone_label_dict = dict([[i[0], i[2] != []] for i in models["standalone"]])
    lora_label_dict = dict([[i[0], i[2] != []] for i in models["lora"]])
    standalone_lmax_dict = dict([[i[0], i[3]] for i in models["standalone"]])
    lora_lmax_dict = dict([[i[0], i[3]] for i in models["lora"]])
    if token_name == "SMILES & SAFE":
        vocab_keys = VOCAB_KEYS
        tokeniser = smiles2vec
        trans_fn = lambda x: [i for i in x if (MolFromSmiles(i) and i)]
        img_fn = lambda x: [Draw.MolToImage(MolFromSmiles(i), (500, 500)) for i in x]
        chemfig_fn = lambda x: [mol2chemfig(i, "-r", inline=True) for i in x]
    if token_name == "FASTA":
        vocab_keys = AA_VOCAB_KEYS
        tokeniser = aa2vec
        trans_fn = lambda x: x
        img_fn = lambda x: [Draw.MolToImage(MolFromFASTA(i), (500, 500)) for i in x]
        chemfig_fn = lambda x: ["null" for _ in x]
    if token_name == "SELFIES":
        vocab_data = load_vocab(vocabs[vocab_fn])
        vocab_keys = vocab_data["vocab_keys"]
        vocab_dict = vocab_data["vocab_dict"]
        tokeniser = partial(selfies2vec, vocab_dict=vocab_dict)
        trans_fn = lambda x: x
        img_fn = lambda x: [
            Draw.MolToImage(MolFromSmiles(decoder(i)), (500, 500)) for i in x
        ]
        chemfig_fn = lambda x: [mol2chemfig(decoder(i), "-r", inline=True) for i in x]
    _method = "bfn" if method == "BFN" else f"ode:{temperature}"
    # ------- build model -------
    prompt_info = parse_prompt(prompt)
    sar_flag = parse_sar_control(sar_control)
    print(prompt_info)
    if not prompt_info["lora"]:
        if model_name in base_model_dict:
            lmax = sequence_size
            bfn = ChemBFN.from_checkpoint(base_model_dict[model_name])
            y = None
            if prompt_info["objective"]:
                _message.append("Objective values ignored by base model.")
        else:
            lmax = standalone_lmax_dict[model_name]
            bfn = ChemBFN.from_checkpoint(
                standalone_model_dict[model_name] / "model.pt"
            )
            if prompt_info["objective"]:
                if not standalone_label_dict[model_name]:
                    y = None
                    _message.append("Objective values ignored.")
                else:
                    mlp = MLP.from_checkpoint(
                        standalone_model_dict[model_name] / "mlp.pt"
                    )
                    y = torch.tensor([prompt_info["objective"]], dtype=torch.float32)
                    y = mlp.forward(y)
            else:
                y = None
            _message.append(f"Sequence length is set to {lmax} from model metadata.")
        bfn.semi_autoregressive = sar_flag[0]
    elif len(prompt_info["lora"]) == 1:
        lmax = lora_lmax_dict[prompt_info["lora"][0]]
        if model_name in base_model_dict:
            bfn = ChemBFN.from_checkpoint(
                base_model_dict[model_name],
                lora_model_dict[prompt_info["lora"][0]] / "lora.pt",
            )
        else:
            bfn = ChemBFN.from_checkpoint(
                standalone_model_dict[model_name] / "model.pt",
                lora_model_dict[prompt_info["lora"][0]] / "lora.pt",
            )
        if prompt_info["objective"]:
            if not lora_label_dict[prompt_info["lora"][0]]:
                y = None
                _message.append("Objective values ignored.")
            else:
                mlp = MLP.from_checkpoint(
                    lora_model_dict[prompt_info["lora"][0]] / "mlp.pt"
                )
                y = torch.tensor([prompt_info["objective"]], dtype=torch.float32)
                y = mlp.forward(y)
        else:
            y = None
        if prompt_info["lora_scaling"][0] != 1.0:
            adjust_lora_(bfn, prompt_info["lora_scaling"][0])
        _message.append(f"Sequence length is set to {lmax} from model metadata.")
        bfn.semi_autoregressive = sar_flag[0]
    else:
        lmax = max([lora_lmax_dict[i] for i in prompt_info["lora"]])
        if model_name in base_model_dict:
            base_model_dir = base_model_dict[model_name]
        else:
            base_model_dir = standalone_model_dict[model_name] / "model.pt"
            lmax = max([lmax, standalone_lmax_dict[model_name]])
        lora_dir = [lora_model_dict[i] / "lora.pt" for i in prompt_info["lora"]]
        mlps = [
            MLP.from_checkpoint(lora_model_dict[i] / "mlp.pt")
            for i in prompt_info["lora"]
        ]
        weights = prompt_info["lora_scaling"]
        if len(sar_flag) == 1:
            sar_flag = [sar_flag[0] for _ in range(len(weights))]
        bfn = EnsembleChemBFN(base_model_dir, lora_dir, mlps, weights)
        y = [torch.tensor([i], dtype=torch.float32) for i in prompt_info["objective"]]
        _message.append(f"Sequence length is set to {lmax} from model metadata.")
    # ------- inference -------
    allowed_tokens = parse_exclude_token(exclude_token, vocab_keys)
    if not allowed_tokens:
        allowed_tokens = "all"
    scaffold = scaffold.strip()
    if not scaffold:
        mols = sample(
            bfn,
            batch_size,
            lmax,
            step,
            y,
            guidance_strength,
            vocab_keys=vocab_keys,
            method=_method,
            allowed_tokens=allowed_tokens,
        )
        mols = trans_fn(mols)
        imgs = img_fn(mols)
        chemfigs = chemfig_fn(mols)
    else:
        x = [1] + tokeniser(scaffold)
        x = x + [0 for _ in range(lmax - len(x))]
        x = torch.tensor([x], dtype=torch.long).repeat(batch_size, 1)
        mols = inpaint(
            bfn,
            x,
            step,
            y,
            guidance_strength,
            vocab_keys=vocab_keys,
            method=_method,
            allowed_tokens=allowed_tokens,
        )
        mols = trans_fn(mols)
        imgs = img_fn(mols)
        chemfigs = chemfig_fn(mols)
    n_mol = len(mols)
    with open(cache_dir / "results.csv", "w", encoding="utf-8", newline="") as rf:
        rf.write("\n".join(mols))
    _message.append(
        f"{n_mol} samples generated and saved to the cache file that can be downloaded."
    )
    return (
        imgs,
        mols,
        "\n\n".join(chemfigs),
        "\n".join(_message),
        str(cache_dir / "results.csv"),
    )


with gr.Blocks(title="ChemBFN WebUI") as app:
    gr.Markdown("### WebUI to generate and visualise molecules with the ChemBFN method.")
    with gr.Row():
        with gr.Column(scale=1):
            btn = gr.Button("RUN", variant="primary")
            model_name = gr.Dropdown(
                [i[0] for i in models["base"]] + [i[0] for i in models["standalone"]],
                label="model",
            )
            token_name = gr.Dropdown(
                ["SMILES & SAFE", "SELFIES", "FASTA"], label="tokeniser"
            )
            vocab_fn = gr.Dropdown(
                list(vocabs.keys()),
                label="vocabulary",
                visible=token_name.value == "SELFIES",
            )
            step = gr.Slider(1, 5000, 100, step=1, precision=0, label="step")
            batch_size = gr.Slider(1, 512, 1, step=1, precision=0, label="batch size")
            sequence_size = gr.Slider(
                5, 4096, 50, step=1, precision=0, label="sequence length"
            )
            guidance_strength = gr.Slider(
                0, 25, 4, step=0.05, label="guidance strength"
            )
            method = gr.Dropdown(["BFN", "ODE"], label="method")
            temperature = gr.Slider(
                0.0,
                2.5,
                0.5,
                step=0.001,
                label="temperature",
                visible=method.value == "ODE",
            )
        with gr.Column(scale=2):
            with gr.Tab(label="prompt editor"):
                prompt = gr.TextArea(label="prompt", lines=12)
                scaffold = gr.Textbox(label="scaffold")
                gr.Markdown("")
                message = gr.TextArea(label="message")
            with gr.Tab(label="result viewer"):
                with gr.Tab(label="result"):
                    btn_download = gr.File(label="download", visible=False)
                    result = gr.Dataframe(
                        headers=["molecule"],
                        col_count=(1, "fixed"),
                        label="",
                        show_fullscreen_button=True,
                        show_row_numbers=True,
                        show_copy_button=True,
                    )
                with gr.Tab(label="LaTeX Chemfig"):
                    chemfig = gr.Code(
                        label="", language="latex", show_line_numbers=True
                    )
                with gr.Tab(label="gallery"):
                    img = gr.Gallery(label="molecule", columns=4, height=512)
            with gr.Tab(label="model explorer"):
                btn_refresh = gr.Button("refresh", variant="secondary")
                with gr.Tab(label="customised vocabulary"):
                    vocab_table = gr.Dataframe(
                        list(vocabs.keys()),
                        headers=["name"],
                        col_count=(1, "fixed"),
                        label="",
                        interactive=False,
                        show_row_numbers=True,
                    )
                with gr.Tab(label="base models"):
                    base_table = gr.Dataframe(
                        [i[0] for i in models["base"]],
                        headers=["name"],
                        col_count=(1, "fixed"),
                        label="",
                        interactive=False,
                        show_row_numbers=True,
                    )
                with gr.Tab(label="standalone models"):
                    standalone_table = gr.Dataframe(
                        [[i[0], i[2]] for i in models["standalone"]],
                        headers=["name", "objective"],
                        col_count=(2, "fixed"),
                        label="",
                        interactive=False,
                        show_row_numbers=True,
                    )
                with gr.Tab(label="LoRA models"):
                    lora_table = gr.Dataframe(
                        [[i[0], i[2]] for i in models["lora"]],
                        headers=["name", "objective"],
                        col_count=(2, "fixed"),
                        label="",
                        interactive=False,
                        show_row_numbers=True,
                    )
            with gr.Tab(label="advanced control"):
                sar_control = gr.Textbox("F", label="semi-autoregressive behaviour")
                gr.Markdown("")
                exclude_token = gr.TextArea(
                    label="exclude tokens",
                    placeholder="Key in unwanted tokens, separated by commas.",
                )
    # ------ user interaction events -------
    btn.click(
        fn=run,
        inputs=[
            model_name,
            token_name,
            vocab_fn,
            step,
            batch_size,
            sequence_size,
            guidance_strength,
            method,
            temperature,
            prompt,
            scaffold,
            sar_control,
            exclude_token,
        ],
        outputs=[img, result, chemfig, message, btn_download],
    )
    btn_refresh.click(
        fn=refresh,
        inputs=[model_name, vocab_fn, token_name],
        outputs=[
            vocab_table,
            base_table,
            standalone_table,
            lora_table,
            model_name,
            vocab_fn,
        ],
    )
    token_name.input(
        fn=lambda x, y: gr.Dropdown(
            list(vocabs.keys()), value=y, label="vocabulary", visible=x == "SELFIES"
        ),
        inputs=[token_name, vocab_fn],
        outputs=vocab_fn,
    )
    method.input(
        fn=lambda x, y: gr.Slider(
            0.0,
            2.5,
            y,
            step=0.001,
            label="temperature",
            visible=x == "ODE",
        ),
        inputs=[method, temperature],
        outputs=temperature,
    )
    lora_table.select(fn=select_lora, inputs=prompt, outputs=prompt)
    result.change(
        fn=lambda x: gr.File(x, label="download", visible=True),
        inputs=btn_download,
        outputs=btn_download,
    )


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--public", default=False, help="open to public", action="store_true"
    )
    parser.add_argument("-V", "--version", action="version", version=__version__)
    args = parser.parse_args()
    app.launch(share=args.public)


if __name__ == "__main__":
    main()
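The prompt box understands a small grammar parsed by parse_prompt and parse_sar_control from lib/utilities.py (shown further below): a bare bracketed list supplies objective values, <name:scaling> selects a LoRA adapter (optionally followed by :[...] objectives), and semicolons separate multiple adapters. A minimal sketch of what run() receives; the adapter names here are hypothetical, and the imports assume the installed package resolves as chembfn_webui.lib (app.py itself relies on a sys.path hack instead):

from chembfn_webui.lib.utilities import parse_prompt, parse_sar_control

# Objective values only: guidance on a standalone model, no LoRA.
print(parse_prompt("[1.2,3.4]"))
# {'lora': [], 'objective': [[1.2, 3.4]], 'lora_scaling': []}

# One adapter at half strength with two objective values; this is the
# format select_lora() inserts when a row of the "LoRA models" table
# is clicked (always with scaling 1).
print(parse_prompt("<my_lora:0.5>:[1.2,3.4]"))
# {'lora': ['my_lora'], 'objective': [[1.2, 3.4]], 'lora_scaling': [0.5]}

# Several adapters separated by ";" route run() to EnsembleChemBFN.
print(parse_prompt("<my_lora:0.5>:[1.2];\n<other_lora:1>:[0.7]"))
# {'lora': ['my_lora', 'other_lora'], 'objective': [[1.2], [0.7]],
#  'lora_scaling': [0.5, 1.0]}

# The "semi-autoregressive behaviour" textbox takes comma-separated
# T/F flags, one per adapter in the ensemble case.
print(parse_sar_control("T,F"))
# [True, False]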
chembfn_webui/cache/cache_file_here.txt
ADDED
@@ -0,0 +1 @@
Cache file will appear here.
chembfn_webui/cache/results.csv
ADDED
@@ -0,0 +1,15 @@
[C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][O]
[C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][=Branch1][C][=O][N][C][C][C]
[C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][N][C][C][Branch1][=Branch2][C][=C][C][=C][C][=C][Ring1][=Branch1][C][=C][C][=C][C][=C][Ring1][=Branch1]
[C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][O]
[C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
[C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
[C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
[C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
[C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
[C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
[O][=C][Branch2][Ring1][C][N][N][C][=Branch1][C][=O][C][=C][C][=C][Branch1][C][Cl][C][=C][Ring1][#Branch1][C][=C][C][=C][C][=C][Ring1][=Branch1]
[C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
[C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][=Branch1][C][=O][O]
[C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
[C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
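The bundled results.csv is a leftover sample cache from a SELFIES run of run(): one generated string per row, no header, written verbatim by "\n".join(mols). Each row decodes back to SMILES the same way app.py does when rendering the gallery; a short sketch, assuming the file sits in the working directory:

from selfies import decoder
from rdkit.Chem import MolFromSmiles

# Read the cached rows (one SELFIES string per line, no header).
with open("results.csv", encoding="utf-8") as f:
    rows = [line.strip() for line in f if line.strip()]

# The first row is a chain of twenty [C] tokens ending in [O]; decoding
# it should give the linear fatty alcohol "CCCCCCCCCCCCCCCCCCCCO".
smiles = decoder(rows[0])
print(smiles, MolFromSmiles(smiles) is not None)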
chembfn_webui/lib/utilities.py
ADDED
@@ -0,0 +1,134 @@
# -*- coding: utf-8 -*-
# Author: Nianze A. TAO (omozawa SUENO)
"""
Utilities.
"""
import os
import json
from glob import glob
from pathlib import Path
from typing import Dict, List, Tuple, Union

_model_path = Path(__file__).parent.parent / "model"
if "CHEMBFN_WEBUI_MODEL_DIR" in os.environ:
    _model_path = Path(os.environ["CHEMBFN_WEBUI_MODEL_DIR"])


def find_vocab() -> Dict[str, str]:
    vocab_fns = glob(str(_model_path / "vocab/*.txt"))
    return {
        os.path.basename(i).replace(".txt", ""): i
        for i in vocab_fns
        if "place_vocabulary_file_here.txt" not in i
    }


def find_model() -> Dict[str, List[List[Union[str, int, List[str], Path]]]]:
    models = {}
    # find base models
    base_fns = glob(str(_model_path / "base_model/*.pt"))
    models["base"] = [[os.path.basename(i), i] for i in base_fns]
    # find standalone models
    standalone_models = []
    standalone_fns = glob(str(_model_path / "standalone_model/*/model.pt"))
    for standalone_fn in standalone_fns:
        config_fn = Path(standalone_fn).parent / "config.json"
        if not os.path.exists(config_fn):
            continue
        else:
            with open(config_fn, "r", encoding="utf-8") as f:
                config = json.load(f)
            name = config["name"]
            label = config["label"]
            lmax = config["padding_length"]
            standalone_models.append([name, Path(standalone_fn).parent, label, lmax])
    models["standalone"] = standalone_models
    # find LoRA models
    lora_models = []
    lora_fns = glob(str(_model_path / "lora/*/lora.pt"))
    for lora_fn in lora_fns:
        config_fn = Path(lora_fn).parent / "config.json"
        if not os.path.exists(config_fn):
            continue
        else:
            with open(config_fn, "r", encoding="utf-8") as f:
                config = json.load(f)
            name = config["name"]
            label = config["label"]
            lmax = config["padding_length"]
            lora_models.append([name, Path(lora_fn).parent, label, lmax])
    models["lora"] = lora_models
    return models


def _get_lora_info(prompt: str) -> Tuple[str, List[float], float]:
    s = prompt.split(">")
    s1 = s[0].replace("<", "")
    lora_info = s1.split(":")
    lora_name = lora_info[0]
    if len(lora_info) == 1:
        lora_scaling = 1.0
    else:
        lora_scaling = float(lora_info[1])
    if len(s) == 1:
        obj = []
    elif ":" not in s[1]:
        obj = []
    else:
        s2 = s[1].replace(":", "").replace("[", "").replace("]", "").split(",")
        obj = [float(i) for i in s2]
    return lora_name, obj, lora_scaling


def parse_prompt(
    prompt: str,
) -> Dict[str, Union[List[str], List[float], List[List[float]]]]:
    prompt_group = prompt.strip().replace("\n", "").split(";")
    prompt_group = [i for i in prompt_group if i]
    info = {"lora": [], "objective": [], "lora_scaling": []}
    if not prompt_group:
        return info
    if len(prompt_group) == 1:
        if not ("<" in prompt_group[0] and ">" in prompt_group[0]):
            obj = [
                float(i)
                for i in prompt_group[0].replace("[", "").replace("]", "").split(",")
            ]
            info["objective"].append(obj)
        else:
            lora_name, obj, lora_scaling = _get_lora_info(prompt_group[0])
            info["lora"].append(lora_name)
            if obj:
                info["objective"].append(obj)
            info["lora_scaling"].append(lora_scaling)
    else:
        for _prompt in prompt_group:
            if not ("<" in _prompt and ">" in _prompt):
                continue
            lora_name, obj, lora_scaling = _get_lora_info(_prompt)
            info["lora"].append(lora_name)
            if obj:
                info["objective"].append(obj)
            info["lora_scaling"].append(lora_scaling)
    return info


def parse_exclude_token(tokens: str, vocab_keys: List[str]) -> List[str]:
    tokens = tokens.strip().replace("\n", "").split(",")
    tokens = [i for i in tokens if i]
    if not tokens:
        return tokens
    tokens = [i for i in vocab_keys if i not in tokens]
    return tokens


def parse_sar_control(sar_control: str) -> List[bool]:
    sar_flag = sar_control.strip().replace("\n", "").split(",")
    sar_flag = [i for i in sar_flag if i]
    if not sar_flag:
        return [False]
    sar_flag = [i.lower() == "t" for i in sar_flag]
    return sar_flag


if __name__ == "__main__":
    ...
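find_model() only registers a standalone or LoRA model whose subfolder contains both the weight file and a config.json providing name, label, and padding_length. A minimal sketch of a conforming LoRA folder; the folder name, objective name, and padding length below are hypothetical placeholders:

import json
from pathlib import Path

# Expected layout under model/ (standalone_model/ mirrors this with model.pt):
#   model/base_model/<weights>.pt
#   model/lora/<folder>/lora.pt  [+ mlp.pt when label is non-empty]
#   model/lora/<folder>/config.json
#   model/vocab/<vocabulary>.txt
lora_dir = Path("chembfn_webui/model/lora/my_lora")  # hypothetical folder
lora_dir.mkdir(parents=True, exist_ok=True)

config = {
    "name": "my_lora",     # shown in the "LoRA models" table, used in prompts
    "label": ["logP"],     # objective names; an empty list disables guidance
    "padding_length": 64,  # becomes the sampling sequence length (lmax)
}
with open(lora_dir / "config.json", "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)
# Place lora.pt (and mlp.pt if label is non-empty) alongside config.json.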
chembfn_webui/model/base_model/place_base_model_here.txt
ADDED
@@ -0,0 +1 @@
Place thy base model weight files here.
chembfn_webui/model/lora/place_lora_folder_here.txt
ADDED
@@ -0,0 +1 @@
Place thy LoRA weight and configuration files under subfolders here.
chembfn_webui/model/standalone_model/place_standalone_model_folder_here.txt
ADDED
@@ -0,0 +1 @@
Place thy standalone model weight and configuration files under subfolders here.
chembfn_webui/model/vocab/place_vocabulary_file_here.txt
ADDED
@@ -0,0 +1 @@
Place vocabulary files here.