chembfn-webui 0.1.0__py3-none-any.whl

@@ -0,0 +1,5 @@
+ # -*- coding: utf-8 -*-
+ # Author: Nianze A. TAO (omozawa SUENO)
+ """
+ ChemBFN WebUI.
+ """
@@ -0,0 +1,439 @@
+ # -*- coding: utf-8 -*-
+ # Author: Nianze A. TAO (omozawa SUENO)
+ """
+ Define application behaviours.
+ """
+ import sys
+ import argparse
+ from pathlib import Path
+ from functools import partial
+ from typing import Tuple, List, Dict
+
+ # make the bundled "lib" package importable
+ sys.path.append(str(Path(__file__).parent.parent))
+ from rdkit.Chem import Draw, MolFromSmiles, MolFromFASTA
+ from mol2chemfigPy3 import mol2chemfig
+ import gradio as gr
+ import torch
+ from selfies import decoder
+ from bayesianflow_for_chem import ChemBFN, MLP, EnsembleChemBFN
+ from bayesianflow_for_chem.data import (
+     VOCAB_KEYS,
+     AA_VOCAB_KEYS,
+     load_vocab,
+     smiles2vec,
+     aa2vec,
+     split_selfies,
+ )
+ from bayesianflow_for_chem.tool import sample, inpaint, adjust_lora_
+ from lib.utilities import (
+     find_model,
+     find_vocab,
+     parse_prompt,
+     parse_exclude_token,
+     parse_sar_control,
+ )
+ from lib.version import __version__
+
+ vocabs = find_vocab()
+ models = find_model()
+ lora_selected = False  # LoRA select flag
+ cache_dir = Path(__file__).parent.parent / "cache"
+
+
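+ # Tokenise a SELFIES string into vocabulary ids; tokens missing from the
+ # vocabulary fall back to the unknown-token id when one is defined.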
+ def selfies2vec(sel: str, vocab_dict: Dict[str, int]) -> List[int]:
+     s = split_selfies(sel)
+     unknown_id = None
+     for key, idx in vocab_dict.items():
+         if "unknown" in key.lower():
+             unknown_id = idx
+             break
+     return [vocab_dict.get(i, unknown_id) for i in s]
+
+
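+ # Rescan the model and vocabulary folders and rebuild the table values and
+ # dropdown choices shown in the UI.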
+ def refresh(
+     model_selected: str, vocab_selected: str, tokeniser_selected: str
+ ) -> Tuple[
+     List[str], List[str], List[List[str]], List[List[str]], gr.Dropdown, gr.Dropdown
+ ]:
+     global vocabs, models
+     vocabs = find_vocab()
+     models = find_model()
+     a = list(vocabs.keys())
+     b = [i[0] for i in models["base"]]
+     c = [[i[0], i[2]] for i in models["standalone"]]
+     d = [[i[0], i[2]] for i in models["lora"]]
+     e = gr.Dropdown(
+         [i[0] for i in models["base"]] + [i[0] for i in models["standalone"]],
+         value=model_selected,
+         label="model",
+     )
+     f = gr.Dropdown(
+         list(vocabs.keys()),
+         value=vocab_selected,
+         label="vocabulary",
+         visible=tokeniser_selected == "SELFIES",
+     )
+     return a, b, c, d, e, f
+
+
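+ # When a name cell in the LoRA table is clicked, append "<name:1>" to the
+ # prompt; the module-level lora_selected flag avoids handling the same
+ # click twice.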
+ def select_lora(evt: gr.SelectData, prompt: str) -> str:
+     global lora_selected
+     if lora_selected:  # avoid double select
+         lora_selected = False
+         return prompt
+     selected_lora = evt.value
+     lora_selected = True
+     if evt.index[1] != 0:
+         return prompt
+     if not prompt:
+         return f"<{selected_lora}:1>"
+     return f"{prompt};\n<{selected_lora}:1>"
+
+
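+ # Main inference entry: pick the tokeniser, assemble the (optionally
+ # LoRA-adapted or ensembled) ChemBFN model from the prompt, then sample or
+ # inpaint around a scaffold and render images, chemfig code and a CSV file.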
+ def run(
+     model_name: str,
+     token_name: str,
+     vocab_fn: str,
+     step: int,
+     batch_size: int,
+     sequence_size: int,
+     guidance_strength: float,
+     method: str,
+     temperature: float,
+     prompt: str,
+     scaffold: str,
+     sar_control: str,
+     exclude_token: str,
+ ) -> Tuple[List, List[str], str, str, str]:
+     _message = []
+     base_model_dict = dict(models["base"])
+     standalone_model_dict = dict([[i[0], i[1]] for i in models["standalone"]])
+     lora_model_dict = dict([[i[0], i[1]] for i in models["lora"]])
+     standalone_label_dict = dict([[i[0], i[2] != []] for i in models["standalone"]])
+     lora_label_dict = dict([[i[0], i[2] != []] for i in models["lora"]])
+     standalone_lmax_dict = dict([[i[0], i[3]] for i in models["standalone"]])
+     lora_lmax_dict = dict([[i[0], i[3]] for i in models["lora"]])
+     if token_name == "SMILES & SAFE":
+         vocab_keys = VOCAB_KEYS
+         tokeniser = smiles2vec
+         trans_fn = lambda x: [i for i in x if (MolFromSmiles(i) and i)]
+         img_fn = lambda x: [Draw.MolToImage(MolFromSmiles(i), (500, 500)) for i in x]
+         chemfig_fn = lambda x: [mol2chemfig(i, "-r", inline=True) for i in x]
+     if token_name == "FASTA":
+         vocab_keys = AA_VOCAB_KEYS
+         tokeniser = aa2vec
+         trans_fn = lambda x: x
+         img_fn = lambda x: [Draw.MolToImage(MolFromFASTA(i), (500, 500)) for i in x]
+         chemfig_fn = lambda x: ["null" for _ in x]
+     if token_name == "SELFIES":
+         vocab_data = load_vocab(vocabs[vocab_fn])
+         vocab_keys = vocab_data["vocab_keys"]
+         vocab_dict = vocab_data["vocab_dict"]
+         tokeniser = partial(selfies2vec, vocab_dict=vocab_dict)
+         trans_fn = lambda x: x
+         img_fn = lambda x: [
+             Draw.MolToImage(MolFromSmiles(decoder(i)), (500, 500)) for i in x
+         ]
+         chemfig_fn = lambda x: [mol2chemfig(decoder(i), "-r", inline=True) for i in x]
+     _method = "bfn" if method == "BFN" else f"ode:{temperature}"
+     # ------- build model -------
+     prompt_info = parse_prompt(prompt)
+     sar_flag = parse_sar_control(sar_control)
+     print(prompt_info)
+     if not prompt_info["lora"]:
+         if model_name in base_model_dict:
+             lmax = sequence_size
+             bfn = ChemBFN.from_checkpoint(base_model_dict[model_name])
+             y = None
+             if prompt_info["objective"]:
+                 _message.append("Objective values ignored by base model.")
+         else:
+             lmax = standalone_lmax_dict[model_name]
+             bfn = ChemBFN.from_checkpoint(
+                 standalone_model_dict[model_name] / "model.pt"
+             )
+             if prompt_info["objective"]:
+                 if not standalone_label_dict[model_name]:
+                     y = None
+                     _message.append("Objective values ignored.")
+                 else:
+                     mlp = MLP.from_checkpoint(
+                         standalone_model_dict[model_name] / "mlp.pt"
+                     )
+                     y = torch.tensor([prompt_info["objective"]], dtype=torch.float32)
+                     y = mlp.forward(y)
+             else:
+                 y = None
+             _message.append(f"Sequence length is set to {lmax} from model metadata.")
+         bfn.semi_autoregressive = sar_flag[0]
+     elif len(prompt_info["lora"]) == 1:
+         lmax = lora_lmax_dict[prompt_info["lora"][0]]
+         if model_name in base_model_dict:
+             bfn = ChemBFN.from_checkpoint(
+                 base_model_dict[model_name],
+                 lora_model_dict[prompt_info["lora"][0]] / "lora.pt",
+             )
+         else:
+             bfn = ChemBFN.from_checkpoint(
+                 standalone_model_dict[model_name] / "model.pt",
+                 lora_model_dict[prompt_info["lora"][0]] / "lora.pt",
+             )
+         if prompt_info["objective"]:
+             if not lora_label_dict[prompt_info["lora"][0]]:
+                 y = None
+                 _message.append("Objective values ignored.")
+             else:
+                 mlp = MLP.from_checkpoint(
+                     lora_model_dict[prompt_info["lora"][0]] / "mlp.pt"
+                 )
+                 y = torch.tensor([prompt_info["objective"]], dtype=torch.float32)
+                 y = mlp.forward(y)
+         else:
+             y = None
+         if prompt_info["lora_scaling"][0] != 1.0:
+             adjust_lora_(bfn, prompt_info["lora_scaling"][0])
+         _message.append(f"Sequence length is set to {lmax} from model metadata.")
+         bfn.semi_autoregressive = sar_flag[0]
+     else:
+         lmax = max([lora_lmax_dict[i] for i in prompt_info["lora"]])
+         if model_name in base_model_dict:
+             base_model_dir = base_model_dict[model_name]
+         else:
+             base_model_dir = standalone_model_dict[model_name] / "model.pt"
+             lmax = max([lmax, standalone_lmax_dict[model_name]])
+         lora_dir = [lora_model_dict[i] / "lora.pt" for i in prompt_info["lora"]]
+         mlps = [
+             MLP.from_checkpoint(lora_model_dict[i] / "mlp.pt")
+             for i in prompt_info["lora"]
+         ]
+         weights = prompt_info["lora_scaling"]
+         if len(sar_flag) == 1:
+             sar_flag = [sar_flag[0] for _ in range(len(weights))]
+         bfn = EnsembleChemBFN(base_model_dir, lora_dir, mlps, weights)
+         y = [torch.tensor([i], dtype=torch.float32) for i in prompt_info["objective"]]
+         _message.append(f"Sequence length is set to {lmax} from model metadata.")
+     # ------- inference -------
+     allowed_tokens = parse_exclude_token(exclude_token, vocab_keys)
+     if not allowed_tokens:
+         allowed_tokens = "all"
+     scaffold = scaffold.strip()
+     if not scaffold:
+         mols = sample(
+             bfn,
+             batch_size,
+             lmax,
+             step,
+             y,
+             guidance_strength,
+             vocab_keys=vocab_keys,
+             method=_method,
+             allowed_tokens=allowed_tokens,
+         )
+     else:
+         x = [1] + tokeniser(scaffold)
+         x = x + [0 for _ in range(lmax - len(x))]
+         x = torch.tensor([x], dtype=torch.long).repeat(batch_size, 1)
+         mols = inpaint(
+             bfn,
+             x,
+             step,
+             y,
+             guidance_strength,
+             vocab_keys=vocab_keys,
+             method=_method,
+             allowed_tokens=allowed_tokens,
+         )
+     mols = trans_fn(mols)
+     imgs = img_fn(mols)
+     chemfigs = chemfig_fn(mols)
+     n_mol = len(mols)
+     with open(cache_dir / "results.csv", "w", encoding="utf-8", newline="") as rf:
+         rf.write("\n".join(mols))
+     _message.append(
+         f"{n_mol} samples generated and saved to the cache for download."
+     )
+     return (
+         imgs,
+         mols,
+         "\n\n".join(chemfigs),
+         "\n".join(_message),
+         str(cache_dir / "results.csv"),
+     )
+
+
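+ # ------- UI layout -------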
+ with gr.Blocks(title="ChemBFN WebUI") as app:
+     gr.Markdown("### A WebUI for generating and visualising molecules with the ChemBFN method.")
+     with gr.Row():
+         with gr.Column(scale=1):
+             btn = gr.Button("RUN", variant="primary")
+             model_name = gr.Dropdown(
+                 [i[0] for i in models["base"]] + [i[0] for i in models["standalone"]],
+                 label="model",
+             )
+             token_name = gr.Dropdown(
+                 ["SMILES & SAFE", "SELFIES", "FASTA"], label="tokeniser"
+             )
+             vocab_fn = gr.Dropdown(
+                 list(vocabs.keys()),
+                 label="vocabulary",
+                 visible=token_name.value == "SELFIES",
+             )
+             step = gr.Slider(1, 5000, 100, step=1, precision=0, label="step")
+             batch_size = gr.Slider(1, 512, 1, step=1, precision=0, label="batch size")
+             sequence_size = gr.Slider(
+                 5, 4096, 50, step=1, precision=0, label="sequence length"
+             )
+             guidance_strength = gr.Slider(
+                 0, 25, 4, step=0.05, label="guidance strength"
+             )
+             method = gr.Dropdown(["BFN", "ODE"], label="method")
+             temperature = gr.Slider(
+                 0.0,
+                 2.5,
+                 0.5,
+                 step=0.001,
+                 label="temperature",
+                 visible=method.value == "ODE",
+             )
+         with gr.Column(scale=2):
+             with gr.Tab(label="prompt editor"):
+                 prompt = gr.TextArea(label="prompt", lines=12)
+                 scaffold = gr.Textbox(label="scaffold")
+                 gr.Markdown("")
+                 message = gr.TextArea(label="message")
+             with gr.Tab(label="result viewer"):
+                 with gr.Tab(label="result"):
+                     btn_download = gr.File(label="download", visible=False)
+                     result = gr.Dataframe(
+                         headers=["molecule"],
+                         col_count=(1, "fixed"),
+                         label="",
+                         show_fullscreen_button=True,
+                         show_row_numbers=True,
+                         show_copy_button=True,
+                     )
+                 with gr.Tab(label="LaTeX chemfig"):
+                     chemfig = gr.Code(
+                         label="", language="latex", show_line_numbers=True
+                     )
+                 with gr.Tab(label="gallery"):
+                     img = gr.Gallery(label="molecule", columns=4, height=512)
+             with gr.Tab(label="model explorer"):
+                 btn_refresh = gr.Button("refresh", variant="secondary")
+                 with gr.Tab(label="customised vocabulary"):
+                     vocab_table = gr.Dataframe(
+                         list(vocabs.keys()),
+                         headers=["name"],
+                         col_count=(1, "fixed"),
+                         label="",
+                         interactive=False,
+                         show_row_numbers=True,
+                     )
+                 with gr.Tab(label="base models"):
+                     base_table = gr.Dataframe(
+                         [i[0] for i in models["base"]],
+                         headers=["name"],
+                         col_count=(1, "fixed"),
+                         label="",
+                         interactive=False,
+                         show_row_numbers=True,
+                     )
+                 with gr.Tab(label="standalone models"):
+                     standalone_table = gr.Dataframe(
+                         [[i[0], i[2]] for i in models["standalone"]],
+                         headers=["name", "objective"],
+                         col_count=(2, "fixed"),
+                         label="",
+                         interactive=False,
+                         show_row_numbers=True,
+                     )
+                 with gr.Tab(label="LoRA models"):
+                     lora_table = gr.Dataframe(
+                         [[i[0], i[2]] for i in models["lora"]],
+                         headers=["name", "objective"],
+                         col_count=(2, "fixed"),
+                         label="",
+                         interactive=False,
+                         show_row_numbers=True,
+                     )
+             with gr.Tab(label="advanced control"):
+                 sar_control = gr.Textbox("F", label="semi-autoregressive behaviour")
+                 gr.Markdown("")
+                 exclude_token = gr.TextArea(
+                     label="exclude tokens",
+                     placeholder="Key in unwanted tokens separated by commas.",
+                 )
+     # ------- user interaction events -------
+     btn.click(
+         fn=run,
+         inputs=[
+             model_name,
+             token_name,
+             vocab_fn,
+             step,
+             batch_size,
+             sequence_size,
+             guidance_strength,
+             method,
+             temperature,
+             prompt,
+             scaffold,
+             sar_control,
+             exclude_token,
+         ],
+         outputs=[img, result, chemfig, message, btn_download],
+     )
+     btn_refresh.click(
+         fn=refresh,
+         inputs=[model_name, vocab_fn, token_name],
+         outputs=[
+             vocab_table,
+             base_table,
+             standalone_table,
+             lora_table,
+             model_name,
+             vocab_fn,
+         ],
+     )
+     token_name.input(
+         fn=lambda x, y: gr.Dropdown(
+             list(vocabs.keys()), value=y, label="vocabulary", visible=x == "SELFIES"
+         ),
+         inputs=[token_name, vocab_fn],
+         outputs=vocab_fn,
+     )
+     method.input(
+         fn=lambda x, y: gr.Slider(
+             0.0,
+             2.5,
+             y,
+             step=0.001,
+             label="temperature",
+             visible=x == "ODE",
+         ),
+         inputs=[method, temperature],
+         outputs=temperature,
+     )
+     lora_table.select(fn=select_lora, inputs=prompt, outputs=prompt)
+     result.change(
+         fn=lambda x: gr.File(x, label="download", visible=True),
+         inputs=btn_download,
+         outputs=btn_download,
+     )
+
+
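+ # Command-line entry point; --public exposes the app via a Gradio share link.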
+ def main() -> None:
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--public", default=False, help="open to public", action="store_true"
+     )
+     parser.add_argument("-V", "--version", action="version", version=__version__)
+     args = parser.parse_args()
+     app.launch(share=args.public)
+
+
+ if __name__ == "__main__":
+     main()
@@ -0,0 +1 @@
+ Cache files will appear here.
@@ -0,0 +1,15 @@
+ [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][O]
+ [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][=Branch1][C][=O][N][C][C][C]
+ [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][N][C][C][Branch1][=Branch2][C][=C][C][=C][C][=C][Ring1][=Branch1][C][=C][C][=C][C][=C][Ring1][=Branch1]
+ [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][O]
+ [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
+ [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
+ [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
+ [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
+ [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
+ [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
+ [O][=C][Branch2][Ring1][C][N][N][C][=Branch1][C][=O][C][=C][C][=C][Branch1][C][Cl][C][=C][Ring1][#Branch1][C][=C][C][=C][C][=C][Ring1][=Branch1]
+ [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
+ [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][=Branch1][C][=O][O]
+ [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
+ [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C]
@@ -0,0 +1,134 @@
+ # -*- coding: utf-8 -*-
+ # Author: Nianze A. TAO (omozawa SUENO)
+ """
+ Utilities.
+ """
+ import os
+ import json
+ from glob import glob
+ from pathlib import Path
+ from typing import Dict, List, Tuple, Union
+
+ _model_path = Path(__file__).parent.parent / "model"
+ if "CHEMBFN_WEBUI_MODEL_DIR" in os.environ:
+     _model_path = Path(os.environ["CHEMBFN_WEBUI_MODEL_DIR"])
+
+
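+ # Map vocabulary file stems to their paths, skipping the placeholder file.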
+ def find_vocab() -> Dict[str, str]:
+     vocab_fns = glob(str(_model_path / "vocab/*.txt"))
+     return {
+         os.path.basename(i).replace(".txt", ""): i
+         for i in vocab_fns
+         if "place_vocabulary_file_here.txt" not in i
+     }
+
+
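+ # Discover models under the model directory:
+ #   base_model/*.pt                            -> "base" entries [name, path]
+ #   standalone_model/*/model.pt + config.json  -> "standalone" entries
+ #   lora/*/lora.pt + config.json               -> "lora" entries
+ # Configured entries are [name, folder, label, padding_length].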
+ def find_model() -> Dict[str, List[List[Union[str, int, List[str], Path]]]]:
+     models = {}
+     # find base models
+     base_fns = glob(str(_model_path / "base_model/*.pt"))
+     models["base"] = [[os.path.basename(i), i] for i in base_fns]
+     # find standalone models
+     standalone_models = []
+     standalone_fns = glob(str(_model_path / "standalone_model/*/model.pt"))
+     for standalone_fn in standalone_fns:
+         config_fn = Path(standalone_fn).parent / "config.json"
+         if not os.path.exists(config_fn):
+             continue
+         with open(config_fn, "r", encoding="utf-8") as f:
+             config = json.load(f)
+         name = config["name"]
+         label = config["label"]
+         lmax = config["padding_length"]
+         standalone_models.append([name, Path(standalone_fn).parent, label, lmax])
+     models["standalone"] = standalone_models
+     # find LoRA models
+     lora_models = []
+     lora_fns = glob(str(_model_path / "lora/*/lora.pt"))
+     for lora_fn in lora_fns:
+         config_fn = Path(lora_fn).parent / "config.json"
+         if not os.path.exists(config_fn):
+             continue
+         with open(config_fn, "r", encoding="utf-8") as f:
+             config = json.load(f)
+         name = config["name"]
+         label = config["label"]
+         lmax = config["padding_length"]
+         lora_models.append([name, Path(lora_fn).parent, label, lmax])
+     models["lora"] = lora_models
+     return models
+
+
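+ # Parse one prompt fragment of the form "<lora_name:scaling>:[v1,v2,...]";
+ # the scaling factor (default 1.0) and the objective list are optional.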
+ def _get_lora_info(prompt: str) -> Tuple[str, List[float], float]:
+     s = prompt.split(">")
+     s1 = s[0].replace("<", "")
+     lora_info = s1.split(":")
+     lora_name = lora_info[0]
+     if len(lora_info) == 1:
+         lora_scaling = 1.0
+     else:
+         lora_scaling = float(lora_info[1])
+     if len(s) == 1 or ":" not in s[1]:
+         obj = []
+     else:
+         s2 = s[1].replace(":", "").replace("[", "").replace("]", "").split(",")
+         obj = [float(i) for i in s2]
+     return lora_name, obj, lora_scaling
+
+
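+ # Split the prompt on ";" and collect LoRA names, objective value lists and
+ # scaling factors; a bare "[v1,v2,...]" group without a "<...>" tag is read
+ # as objective values for the selected model itself.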
+ def parse_prompt(
+     prompt: str,
+ ) -> Dict[str, Union[List[str], List[float], List[List[float]]]]:
+     prompt_group = prompt.strip().replace("\n", "").split(";")
+     prompt_group = [i for i in prompt_group if i]
+     info = {"lora": [], "objective": [], "lora_scaling": []}
+     if len(prompt_group) == 1:
+         if not ("<" in prompt_group[0] and ">" in prompt_group[0]):
+             obj = [
+                 float(i)
+                 for i in prompt_group[0].replace("[", "").replace("]", "").split(",")
+             ]
+             info["objective"].append(obj)
+         else:
+             lora_name, obj, lora_scaling = _get_lora_info(prompt_group[0])
+             info["lora"].append(lora_name)
+             if obj:
+                 info["objective"].append(obj)
+             info["lora_scaling"].append(lora_scaling)
+     else:
+         for _prompt in prompt_group:
+             if not ("<" in _prompt and ">" in _prompt):
+                 continue
+             lora_name, obj, lora_scaling = _get_lora_info(_prompt)
+             info["lora"].append(lora_name)
+             if obj:
+                 info["objective"].append(obj)
+             info["lora_scaling"].append(lora_scaling)
+     return info
+
+
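+ # Return the vocabulary keys that remain allowed after excluding the given
+ # comma-separated tokens; an empty input returns an empty list (the caller
+ # then falls back to allowing all tokens).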
+ def parse_exclude_token(tokens: str, vocab_keys: List[str]) -> List[str]:
+     tokens = tokens.strip().replace("\n", "").split(",")
+     tokens = [i for i in tokens if i]
+     if not tokens:
+         return tokens
+     return [i for i in vocab_keys if i not in tokens]
+
+
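+ # Parse comma-separated "T"/"F" flags that switch semi-autoregressive
+ # sampling on or off; an empty field defaults to [False].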
+ def parse_sar_control(sar_control: str) -> List[bool]:
+     sar_flag = sar_control.strip().replace("\n", "").split(",")
+     sar_flag = [i for i in sar_flag if i]
+     if not sar_flag:
+         return [False]
+     return [i.lower() == "t" for i in sar_flag]
+
+
+ if __name__ == "__main__":
+     ...
@@ -0,0 +1,8 @@
+ # -*- coding: utf-8 -*-
+ # Author: Nianze A. TAO (omozawa SUENO)
+ """
+ Version info.
+ """
+
+ __version__ = "0.1.0"
+ __author__ = "Nianze A. TAO"
@@ -0,0 +1 @@
+ Place thy base model weight files here.
@@ -0,0 +1 @@
+ Place thy LoRA weight and configuration files under subfolders here.
@@ -0,0 +1 @@
+ Place thy standalone model weight and configuration files under subfolders here.
@@ -0,0 +1 @@
+ Place vocabulary files here.