bayesianflow-for-chem 2.0.5__tar.gz → 2.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bayesianflow-for-chem might be problematic.
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/PKG-INFO +1 -1
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/__init__.py +1 -1
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/cli.py +35 -4
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/model.py +168 -1
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/tool.py +91 -1
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem.egg-info/PKG-INFO +1 -1
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem.egg-info/SOURCES.txt +1 -0
- bayesianflow_for_chem-2.1.0/test/test_jit_compatibility.py +28 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/LICENSE +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/README.md +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/data.py +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/scorer.py +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/spectra.py +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/train.py +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/vocab.txt +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem.egg-info/dependency_links.txt +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem.egg-info/entry_points.txt +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem.egg-info/requires.txt +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem.egg-info/top_level.txt +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/pyproject.toml +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/setup.cfg +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/setup.py +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/test/test_merge_lora.py +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/test/test_molecular_embedding.py +0 -0
bayesianflow_for_chem/cli.py

@@ -32,7 +32,7 @@ from bayesianflow_for_chem.data import (
     collate,
     CSVData,
 )
-from bayesianflow_for_chem.tool import sample, inpaint
+from bayesianflow_for_chem.tool import sample, inpaint, optimise, adjust_lora_
 
 
 """
@@ -99,9 +99,11 @@ sample_size = 1000 # the minimum number of samples you want
 sample_step = 100
 sample_method = "ODE:0.5" # ODE-solver with temperature of 0.5; another choice is "BFN"
 semi_autoregressive = false
+lora_scaling = 1.0 # LoRA scaling if applied
 guidance_objective = [-0.023, 0.09, 0.113] # if no objective is needed set it to empty array []
 guidance_objective_strength = 4.0 # unnecessary if guidance_objective = []
 guidance_scaffold = "c1ccccc1" # if no scaffold is used set it to empty string ""
+sample_template = "" # template for mol2mol task; leave it blank if scaffold is used
 unwanted_token = []
 exclude_invalid = true # to only store valid samples
 exclude_duplicate = true # to only store unique samples
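
The two new keys above ship with the default configuration template. As a rough illustration (not part of the package) of how they interact, assuming the config is plain TOML loaded with Python's tomllib and a hypothetical file name:

    # Illustrative only: how the two new [inference] keys relate to each other.
    # The real parsing lives in load_runtime_config (bayesianflow_for_chem/cli.py);
    # the file name and tomllib loading here are assumptions.
    import tomllib  # Python 3.11+

    with open("madmol_config.toml", "rb") as f:  # hypothetical config file
        config = tomllib.load(f)

    inference = config["inference"]
    lora_scaling = inference.get("lora_scaling", 1.0)   # new in 2.1.0, defaults to 1.0
    scaffold = inference.get("guidance_scaffold", "")
    template = inference.get("sample_template", "")     # new in 2.1.0, enables mol2mol
    if scaffold and template:
        # mirrors the new warning in load_runtime_config: the two modes are exclusive
        print("Warning: Inpaint task or mol2mol task?")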
@@ -130,7 +132,7 @@ def parse_cli(version: str) -> argparse.Namespace:
     """
     parser = argparse.ArgumentParser(
         description="Madmol: a CLI molecular design tool for "
-        "de novo design, R-group replacement, and sequence in-filling, "
+        "de novo design, R-group replacement, molecule optimisation, and sequence in-filling, "
         "based on generative route of ChemBFN method. "
         "Let's make some craziest molecules.",
         epilog=f"Madmol {version}, developed in Hiroshima University by chemists for chemists. "
@@ -157,7 +159,7 @@ def parse_cli(version: str) -> argparse.Namespace:
         "-D",
         "--dryrun",
         action="store_true",
-        help="dry-run to check the configurations",
+        help="dry-run to check the configurations and exit",
     )
     parser.add_argument("-V", "--version", action="version", version=version)
     return parser.parse_args()
@@ -284,6 +286,14 @@ def load_runtime_config(
             f"\033[0;33mWarning\033[0;0m in {config_file}: Directory {result_dir} to save the result does not exist."
         )
         flag_warning += 1
+    if (
+        config["inference"]["guidance_scaffold"] != ""
+        and config["inference"]["sample_template"] != ""
+    ):
+        print(
+            f"\033[0;33mWarning\033[0;0m in {config_file}: Inpaint task or mol2mol task?"
+        )
+        flag_warning += 1
     return config, flag_critical, flag_warning
 
 
@@ -520,6 +530,7 @@ def main_script(version: str) -> None:
     if "train" in runtime_config:
         bfn = model.model
         mlp = model.mlp
+    lora_scaling = runtime_config["inference"].get("lora_scaling", 1.0)
     # ####### strat inference #######
     bfn.semi_autoregressive = runtime_config["inference"]["semi_autoregressive"]
     _device = (
@@ -550,8 +561,16 @@ def main_script(version: str) -> None:
             x[:-1], (0, sequence_length - x.shape[-1] + 1), value=0
         )
         x = x[None, :].repeat(batch_size, 1)
+        # then sample template will be ignored.
+    elif runtime_config["inference"]["sample_template"]:
+        template = runtime_config["inference"]["sample_template"]
+        x = tokeniser(template)
+        x = torch.nn.functional.pad(x, (0, sequence_length - x.shape[-1]), value=0)
+        x = x[None, :].repeat(batch_size, 1)
     else:
         x = None
+    if bfn.lora_enabled:
+        adjust_lora_(bfn, lora_scaling)
     mols = []
     while len(mols) < runtime_config["inference"]["sample_size"]:
         if x is None:
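
The template branch above reduces to a simple pad-and-tile of the tokenised template. A minimal stand-alone sketch of that preparation, with a stand-in token tensor in place of the package tokeniser:

    # Stand-alone sketch of the template preparation added for mol2mol: a 1-D tensor
    # of token indices is right-padded with 0 to the model sequence length and tiled
    # over the batch. The token ids below are fake; in the package they come from the
    # tokeniser configured in cli.py.
    import torch

    def prepare_template(token_ids: torch.Tensor, sequence_length: int, batch_size: int) -> torch.Tensor:
        x = torch.nn.functional.pad(token_ids, (0, sequence_length - token_ids.shape[-1]), value=0)
        return x[None, :].repeat(batch_size, 1)  # shape: (batch_size, sequence_length)

    template_ids = torch.tensor([1, 7, 8, 9, 2])            # fake 5-token template
    x = prepare_template(template_ids, sequence_length=16, batch_size=3)
    assert x.shape == (3, 16)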
@@ -567,7 +586,7 @@ def main_script(version: str) -> None:
                 method=sample_method,
                 allowed_tokens=allowed_token,
             )
-
+        elif runtime_config["inference"]["guidance_scaffold"]:
             s = inpaint(
                 bfn,
                 x,
@@ -579,6 +598,18 @@ def main_script(version: str) -> None:
                 method=sample_method,
                 allowed_tokens=allowed_token,
             )
+        else:
+            s = optimise(
+                bfn,
+                x,
+                sample_step,
+                y,
+                guidance_strength,
+                _device,
+                vocab_keys,
+                method=sample_method,
+                allowed_tokens=allowed_token,
+            )
         if runtime_config["inference"]["exclude_invalid"]:
             s = [i for i in s if i]
         if tokeniser_name == "smiles" or tokeniser_name == "safe":
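
Taken together, the hunks above turn the generation loop in main_script into a three-way dispatch. A condensed, illustrative paraphrase (the real call sites pass many more arguments):

    # Condensed paraphrase of the resulting dispatch; the real code passes the model,
    # device, vocabulary, method, and token masks to sample/inpaint/optimise.
    from typing import Optional
    import torch

    def choose_generation_mode(x: Optional[torch.Tensor], scaffold: str, template: str) -> str:
        if x is None:          # neither guidance_scaffold nor sample_template was set
            return "de novo"   # -> bayesianflow_for_chem.tool.sample
        if scaffold:
            return "inpaint"   # -> tool.inpaint (R-group replacement / in-filling)
        assert template, "x was built from sample_template in this branch"
        return "mol2mol"       # -> tool.optimise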
bayesianflow_for_chem/model.py

@@ -676,7 +676,7 @@ class ChemBFN(nn.Module):
         token_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
         """
-        Sample from a piror distribution.
+        Sample from a uniform piror distribution.
 
         :param batch_size: batch size
         :param sequence_size: max sequence length
@@ -880,6 +880,108 @@ class ChemBFN(nn.Module):
             p = p.masked_fill_(token_mask, 0.0)
         return torch.argmax(p, -1), entropy
 
+    @torch.jit.export
+    def optimise(
+        self,
+        x: Tensor,
+        y: Optional[Tensor] = None,
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Optimise the template molecule (mol2mol). \n
+        This method is equivalent to sampling from a customised prior distribution.
+
+        :param x: categorical indices of template; shape: (n_b, n_t)
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask assigning unwanted token(s) with `True`;
+                           shape: (1, 1, n_vocab)
+        :type x: torch.Tensor
+        :type y: torch.Tensor | None
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :return: sampled token indices; shape: (n_b, n_t) \n
+                 entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        n_b = x.shape[0]
+        x_onehot = nn.functional.one_hot(x, self.K).float()
+        theta = nn.functional.softmax(x_onehot, -1)
+        if y is not None:
+            y = self.reshape_y(y)
+        for i in torch.linspace(1, sample_step, sample_step, device=x.device):
+            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
+            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
+            if token_mask is not None:
+                p = p.masked_fill_(token_mask, 0.0)
+            alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
+            e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
+            mu = alpha * (self.K * e_k - 1)
+            sigma = (alpha * self.K).sqrt()
+            theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
+            theta = theta / theta.sum(-1, True)
+        t_final = torch.ones((n_b, 1, 1), device=x.device)
+        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
+        entropy = -(p * p.log()).sum(-1).mean(-1)
+        if token_mask is not None:
+            p = p.masked_fill_(token_mask, 0.0)
+        return torch.argmax(p, -1), entropy
+
+    @torch.jit.export
+    def ode_optimise(
+        self,
+        x: Tensor,
+        y: Optional[Tensor] = None,
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+        temperature: float = 0.5,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        ODE mol2mol.
+
+        :param x: categorical indices of template; shape: (n_b, n_t)
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask assigning unwanted token(s) with `True`;
+                           shape: (1, 1, n_vocab)
+        :param temperature: sampling temperature
+        :type x: torch.Tensor
+        :type y: torch.Tensor | None
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :type temperature: float
+        :return: sampled token indices; shape: (n_b, n_t) \n
+                 entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        n_b = x.shape[0]
+        z = nn.functional.one_hot(x, self.K).float()
+        if y is not None:
+            y = self.reshape_y(y)
+        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
+            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
+            theta = torch.softmax(z, -1)
+            beta = self.calc_beta(t + 1 / sample_step)
+            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
+            if token_mask is not None:
+                p = p.masked_fill_(token_mask, 0.0)
+            u = torch.randn_like(z)
+            z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
+        t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
+        theta = torch.softmax(z, -1)
+        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
+        entropy = -(p * p.log()).sum(-1).mean(-1)
+        if token_mask is not None:
+            p = p.masked_fill_(token_mask, 0.0)
+        return torch.argmax(p, -1), entropy
+
     def inference(
         self, x: Tensor, mlp: MLP, embed_fn: Optional[Callable[[Tensor], Tensor]] = None
     ) -> Tensor:
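
A minimal usage sketch of the new methods, following the ChemBFN(512) construction used in test/test_jit_compatibility.py; the random index tensor is a stand-in for a real tokenised template, and an untrained model will not produce meaningful molecules:

    # Minimal usage sketch (assumptions: ChemBFN(512) as in the new test file,
    # a random index tensor in place of a tokenised template, untrained weights).
    import torch
    from bayesianflow_for_chem import ChemBFN

    model = ChemBFN(512).eval()
    x = torch.randint(0, 8, (2, 32))  # (n_b, n_t); indices must lie inside the vocabulary
    with torch.inference_mode():
        tokens, entropy = model.optimise(x, y=None, sample_step=10)
        ode_tokens, ode_entropy = model.ode_optimise(x, sample_step=10, temperature=0.5)
    # documented return shapes: tokens (n_b, n_t), entropy (n_b,)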
@@ -1250,6 +1352,71 @@ class EnsembleChemBFN(ChemBFN):
             x, y, sample_step, guidance_strength, token_mask, temperature
         )
 
+    @torch.inference_mode()
+    def optimise(
+        self,
+        x: Tensor,
+        conditions: Union[List[Tensor], Dict[str, Tensor]],
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Optimise the template molecule (mol2mol). \n
+        This method is equivalent to sampling from a customised prior distribution.
+
+        :param x: categorical indices of template; shape: (n_b, n_t)
+        :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask assigning unwanted token(s) with `True`;
+                           shape: (1, 1, n_vocab)
+        :type x: torch.Tensor
+        :type y: torch.Tensor | None
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :return: sampled token indices; shape: (n_b, n_t) \n
+                 entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        y = self.construct_y(conditions)
+        return super().optimise(x, y, sample_step, guidance_strength, token_mask)
+
+    @torch.inference_mode()
+    def ode_optimise(
+        self,
+        x: Tensor,
+        conditions: Union[List[Tensor], Dict[str, Tensor]],
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+        temperature: float = 0.5,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        ODE inpainting.
+
+        :param x: categorical indices of template; shape: (n_b, n_t)
+        :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask; shape: (1, 1, n_vocab)
+        :param temperature: sampling temperature
+        :type x: torch.Tensor
+        :type conditions: list | dict
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :type temperature: float
+        :return: sampled token indices; shape: (n_b, n_t) \n
+                 entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        y = self.construct_y(conditions)
+        return super().ode_optimise(
+            x, y, sample_step, guidance_strength, token_mask, temperature
+        )
+
     def quantise(
         self, quantise_method: Optional[Callable[[ChemBFN], None]] = None
     ) -> None:
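
The ensemble wrappers only assemble the joint conditioning vector through construct_y and defer to ChemBFN. A hedged sketch of driving them, assuming an already-constructed EnsembleChemBFN (its constructor is outside this diff):

    # Hedged sketch: how the new wrappers could be driven. EnsembleChemBFN construction
    # and the exact condition heads are not part of this diff.
    from typing import Dict
    import torch
    from bayesianflow_for_chem.model import EnsembleChemBFN

    def ensemble_mol2mol(
        ensemble: EnsembleChemBFN,
        x: torch.Tensor,
        conditions: Dict[str, torch.Tensor],
        use_ode: bool = True,
    ):
        if use_ode:
            return ensemble.ode_optimise(x, conditions, sample_step=100, temperature=0.5)
        return ensemble.optimise(x, conditions, sample_step=100)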
bayesianflow_for_chem/tool.py

@@ -219,7 +219,7 @@ def sample(
     sort: bool = False,
 ) -> List[str]:
     """
-    Sampling.
+    Sampling molecules.
 
     :param model: trained ChemBFN model
     :param batch_size: batch size
@@ -385,6 +385,96 @@ def inpaint(
     ]
 
 
+@torch.no_grad()
+def optimise(
+    model: Union[ChemBFN, EnsembleChemBFN],
+    x: Tensor,
+    sample_step: int = 100,
+    y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]] = None,
+    guidance_strength: float = 4.0,
+    device: Union[str, torch.device, None] = None,
+    vocab_keys: List[str] = VOCAB_KEYS,
+    separator: str = "",
+    method: str = "BFN",
+    allowed_tokens: Union[str, List[str]] = "all",
+    sort: bool = False,
+) -> List[str]:
+    """
+    Optimising template molecules (mol2mol).
+
+    :param model: trained ChemBFN model
+    :param x: categorical indices of template; shape: (n_b, n_t)
+    :param sample_step: number of sampling steps
+    :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f) \n
+              or a list/`dict` of conditions; shape: (n_b, n_c) * n_h
+
+    :param guidance_strength: strength of conditional generation. It is not used if y is null.
+    :param device: hardware accelerator
+    :param vocab_keys: a list of (ordered) vocabulary
+    :param separator: token separator; default is `""`
+    :param method: sampling method chosen from `"ODE:x"` or `"BFN"` where `x` is the value of sampling temperature; default is `"BFN"`
+    :param allowed_tokens: a list of allowed tokens; default is `"all"`
+    :param sort: whether to sort the samples according to entropy values; default is `False`
+    :type model: bayesianflow_for_chem.model.ChemBFN | bayesianflow_for_chem.model.EnsembleChemBFN
+    :type x: torch.Tensor
+    :type sample_step: int
+    :type y: torch.Tensor | list | dict | None
+    :type guidance_strength: float
+    :type device: str | torch.device | None
+    :type vocab_keys: list
+    :type separator: str
+    :type method: str
+    :type allowed_tokens: str | list
+    :type sort: bool
+    :return: a list of generated molecular strings
+    :rtype: list
+    """
+    assert method.split(":")[0].lower() in ("ode", "bfn")
+    if isinstance(model, EnsembleChemBFN):
+        assert y is not None, "conditioning is required while using an ensemble model."
+        assert isinstance(y, list) or isinstance(y, dict)
+    else:
+        assert isinstance(y, Tensor) or y is None
+    if device is None:
+        device = _find_device()
+    model.to(device).eval()
+    x = x.to(device)
+    if y is not None:
+        if isinstance(y, Tensor):
+            y = y.to(device)
+        elif isinstance(y, list):
+            y = [i.to(device) for i in y]
+        elif isinstance(y, dict):
+            y = {k: v.to(device) for k, v in y.items()}
+        else:
+            raise NotImplementedError
+    if isinstance(allowed_tokens, list):
+        token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
+        token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
+    else:
+        token_mask = None
+    if "ode" in method.lower():
+        tp = float(method.split(":")[-1])
+        assert tp > 0, "Sampling temperature should be higher than 0."
+        tokens, entropy = model.ode_optimise(
+            x, y, sample_step, guidance_strength, token_mask, tp
+        )
+    else:
+        tokens, entropy = model.optimise(
+            x, y, sample_step, guidance_strength, token_mask
+        )
+    if sort:
+        sorted_idx = entropy.argsort(stable=True)
+        tokens = tokens[sorted_idx]
+    return [
+        separator.join([vocab_keys[i] for i in j])
+        .split("<start>" + separator)[-1]
+        .split(separator + "<end>")[0]
+        .replace("<pad>", "")
+        for j in tokens
+    ]
+
+
 def quantise_model_(model: ChemBFN) -> None:
     """
     In-place dynamic quantisation of the trained model to `int8` data type. \n
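
A rough end-to-end sketch of the new high-level helper; the tokenisation step is not shown in this diff, so a stand-in index tensor is used, and the untrained ChemBFN(512) stands in for a trained checkpoint:

    # Rough end-to-end sketch of tool.optimise. The stand-in index tensor replaces a
    # tokenised template batch, and ChemBFN(512) replaces a trained checkpoint.
    import torch
    from bayesianflow_for_chem import ChemBFN
    from bayesianflow_for_chem.tool import optimise

    model = ChemBFN(512)
    x = torch.randint(0, 8, (16, 32))     # (n_b, n_t) template token indices
    mols = optimise(
        model,
        x,
        sample_step=100,
        y=None,                            # or a conditioning tensor of shape (n_b, n_f)
        method="ODE:0.5",                  # ODE solver at temperature 0.5; "BFN" also works
        sort=True,                         # order results by entropy
    )
    print(mols[:5])                        # generated molecular strings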
bayesianflow_for_chem.egg-info/SOURCES.txt

@@ -17,5 +17,6 @@ bayesianflow_for_chem.egg-info/dependency_links.txt
 bayesianflow_for_chem.egg-info/entry_points.txt
 bayesianflow_for_chem.egg-info/requires.txt
 bayesianflow_for_chem.egg-info/top_level.txt
+test/test_jit_compatibility.py
 test/test_merge_lora.py
 test/test_molecular_embedding.py
test/test_jit_compatibility.py (new file)

@@ -0,0 +1,28 @@
+# -*- coding: utf-8 -*-
+# Author: Nianze A. Tao (Omozawa Sueno)
+"""
+Model should be compatible with TorchScript.
+"""
+import torch
+from bayesianflow_for_chem import ChemBFN
+
+model = ChemBFN(512)
+model_method = [
+    "sample",
+    "ode_sample",
+    "inpaint",
+    "ode_inpaint",
+    "optimise",
+    "ode_optimise",
+]
+
+
+@torch.inference_mode()
+def test():
+    jit_model = torch.jit.script(model).eval()
+    assert isinstance(jit_model, torch.jit.ScriptModule)
+    for method in model_method:
+        assert hasattr(jit_model, method)
+    jit_model = torch.jit.freeze(jit_model, model_method)
+    for method in model_method:
+        assert hasattr(jit_model, method)
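
As a complementary sketch (not part of the test), a scripted and frozen model can be serialised and reloaded with standard TorchScript APIs; the file name is arbitrary:

    # Complementary sketch: a scripted, frozen model can be saved and reloaded with
    # standard TorchScript APIs. The output file name is arbitrary.
    import torch
    from bayesianflow_for_chem import ChemBFN

    scripted = torch.jit.script(ChemBFN(512)).eval()
    scripted = torch.jit.freeze(
        scripted, ["sample", "ode_sample", "inpaint", "ode_inpaint", "optimise", "ode_optimise"]
    )
    torch.jit.save(scripted, "chembfn_scripted.pt")
    reloaded = torch.jit.load("chembfn_scripted.pt")
    assert hasattr(reloaded, "optimise")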