bayesianflow-for-chem 2.0.5__tar.gz → 2.1.0__tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.


Files changed (24)
  1. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/PKG-INFO +1 -1
  2. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/__init__.py +1 -1
  3. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/cli.py +35 -4
  4. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/model.py +168 -1
  5. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/tool.py +91 -1
  6. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem.egg-info/PKG-INFO +1 -1
  7. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem.egg-info/SOURCES.txt +1 -0
  8. bayesianflow_for_chem-2.1.0/test/test_jit_compatibility.py +28 -0
  9. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/LICENSE +0 -0
  10. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/README.md +0 -0
  11. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/data.py +0 -0
  12. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/scorer.py +0 -0
  13. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/spectra.py +0 -0
  14. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/train.py +0 -0
  15. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/vocab.txt +0 -0
  16. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem.egg-info/dependency_links.txt +0 -0
  17. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem.egg-info/entry_points.txt +0 -0
  18. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem.egg-info/requires.txt +0 -0
  19. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem.egg-info/top_level.txt +0 -0
  20. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/pyproject.toml +0 -0
  21. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/setup.cfg +0 -0
  22. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/setup.py +0 -0
  23. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/test/test_merge_lora.py +0 -0
  24. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/test/test_molecular_embedding.py +0 -0
{bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: bayesianflow_for_chem
- Version: 2.0.5
+ Version: 2.1.0
  Summary: Bayesian flow network framework for Chemistry
  Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
  Author: Nianze A. Tao
{bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/__init__.py
@@ -18,7 +18,7 @@ __all__ = [
      "MLP",
      "EnsembleChemBFN",
  ]
- __version__ = "2.0.5"
+ __version__ = "2.1.0"
  __author__ = "Nianze A. Tao (Omozawa Sueno)"
{bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/cli.py
@@ -32,7 +32,7 @@ from bayesianflow_for_chem.data import (
      collate,
      CSVData,
  )
- from bayesianflow_for_chem.tool import sample, inpaint
+ from bayesianflow_for_chem.tool import sample, inpaint, optimise, adjust_lora_


  """
@@ -99,9 +99,11 @@ sample_size = 1000 # the minimum number of samples you want
  sample_step = 100
  sample_method = "ODE:0.5" # ODE-solver with temperature of 0.5; another choice is "BFN"
  semi_autoregressive = false
+ lora_scaling = 1.0 # LoRA scaling if applied
  guidance_objective = [-0.023, 0.09, 0.113] # if no objective is needed set it to empty array []
  guidance_objective_strength = 4.0 # unnecessary if guidance_objective = []
  guidance_scaffold = "c1ccccc1" # if no scaffold is used set it to empty string ""
+ sample_template = "" # template for mol2mol task; leave it blank if scaffold is used
  unwanted_token = []
  exclude_invalid = true # to only store valid samples
  exclude_duplicate = true # to only store unique samples
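The two added keys define the new mol2mol workflow: lora_scaling appears to rescale the model's LoRA adapters at inference time (applied via adjust_lora_ further down), and sample_template supplies the molecule to optimise. They are meant to be used instead of guidance_scaffold, not alongside it; the check added to load_runtime_config below warns when both are set. A minimal sketch of a 2.1.0-style [inference] table parsed with the standard library; the key names come from the diff above, while the concrete values are illustrative only:

    import tomllib  # standard library since Python 3.11

    config_text = """
    [inference]
    sample_step = 100
    sample_method = "ODE:0.5"
    semi_autoregressive = false
    lora_scaling = 1.0             # new in 2.1.0: LoRA scaling if applied
    guidance_scaffold = ""         # leave empty when a template is used
    sample_template = "c1ccccc1O"  # new in 2.1.0: template for the mol2mol task
    """

    inference = tomllib.loads(config_text)["inference"]
    # Mirror the mutual-exclusion warning added to load_runtime_config:
    assert not (inference["guidance_scaffold"] and inference["sample_template"])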
@@ -130,7 +132,7 @@ def parse_cli(version: str) -> argparse.Namespace:
      """
      parser = argparse.ArgumentParser(
          description="Madmol: a CLI molecular design tool for "
-         "de novo design, R-group replacement, and sequence in-filling, "
+         "de novo design, R-group replacement, molecule optimisation, and sequence in-filling, "
          "based on generative route of ChemBFN method. "
          "Let's make some craziest molecules.",
          epilog=f"Madmol {version}, developed in Hiroshima University by chemists for chemists. "
@@ -157,7 +159,7 @@ def parse_cli(version: str) -> argparse.Namespace:
          "-D",
          "--dryrun",
          action="store_true",
-         help="dry-run to check the configurations",
+         help="dry-run to check the configurations and exit",
      )
      parser.add_argument("-V", "--version", action="version", version=version)
      return parser.parse_args()
@@ -284,6 +286,14 @@ def load_runtime_config(
              f"\033[0;33mWarning\033[0;0m in {config_file}: Directory {result_dir} to save the result does not exist."
          )
          flag_warning += 1
+     if (
+         config["inference"]["guidance_scaffold"] != ""
+         and config["inference"]["sample_template"] != ""
+     ):
+         print(
+             f"\033[0;33mWarning\033[0;0m in {config_file}: Inpaint task or mol2mol task?"
+         )
+         flag_warning += 1
      return config, flag_critical, flag_warning
@@ -520,6 +530,7 @@ def main_script(version: str) -> None:
      if "train" in runtime_config:
          bfn = model.model
          mlp = model.mlp
+     lora_scaling = runtime_config["inference"].get("lora_scaling", 1.0)
      # ####### strat inference #######
      bfn.semi_autoregressive = runtime_config["inference"]["semi_autoregressive"]
      _device = (
@@ -550,8 +561,16 @@ def main_script(version: str) -> None:
              x[:-1], (0, sequence_length - x.shape[-1] + 1), value=0
          )
          x = x[None, :].repeat(batch_size, 1)
+         # then sample template will be ignored.
+     elif runtime_config["inference"]["sample_template"]:
+         template = runtime_config["inference"]["sample_template"]
+         x = tokeniser(template)
+         x = torch.nn.functional.pad(x, (0, sequence_length - x.shape[-1]), value=0)
+         x = x[None, :].repeat(batch_size, 1)
      else:
          x = None
+     if bfn.lora_enabled:
+         adjust_lora_(bfn, lora_scaling)
      mols = []
      while len(mols) < runtime_config["inference"]["sample_size"]:
          if x is None:
@@ -567,7 +586,7 @@ def main_script(version: str) -> None:
                  method=sample_method,
                  allowed_tokens=allowed_token,
              )
-         else:
+         elif runtime_config["inference"]["guidance_scaffold"]:
              s = inpaint(
                  bfn,
                  x,
@@ -579,6 +598,18 @@ def main_script(version: str) -> None:
                  method=sample_method,
                  allowed_tokens=allowed_token,
              )
+         else:
+             s = optimise(
+                 bfn,
+                 x,
+                 sample_step,
+                 y,
+                 guidance_strength,
+                 _device,
+                 vocab_keys,
+                 method=sample_method,
+                 allowed_tokens=allowed_token,
+             )
          if runtime_config["inference"]["exclude_invalid"]:
              s = [i for i in s if i]
          if tokeniser_name == "smiles" or tokeniser_name == "safe":
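In the new branch above, the template string is tokenised, right-padded to the model's sequence length, and tiled across the batch. A self-contained sketch of just that tensor plumbing, using a toy stand-in for the package's tokeniser (the real one comes from bayesianflow_for_chem.data):

    import torch

    def toy_tokeniser(s: str) -> torch.Tensor:
        # hypothetical stand-in: the real tokeniser maps a molecular
        # string to vocabulary indices
        return torch.tensor([ord(c) % 100 for c in s], dtype=torch.long)

    sequence_length, batch_size = 64, 8
    x = toy_tokeniser("c1ccccc1O")  # template molecule
    x = torch.nn.functional.pad(x, (0, sequence_length - x.shape[-1]), value=0)
    x = x[None, :].repeat(batch_size, 1)
    print(x.shape)  # torch.Size([8, 64])

Note that, unlike the scaffold branch, the template branch keeps the final token (no x[:-1]): the whole molecule seeds the prior for optimisation rather than marking positions to in-fill.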
{bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/model.py
@@ -676,7 +676,7 @@ class ChemBFN(nn.Module):
          token_mask: Optional[Tensor] = None,
      ) -> Tuple[Tensor, Tensor]:
          """
-         Sample from a piror distribution.
+         Sample from a uniform piror distribution.

          :param batch_size: batch size
          :param sequence_size: max sequence length
@@ -880,6 +880,108 @@ class ChemBFN(nn.Module):
              p = p.masked_fill_(token_mask, 0.0)
          return torch.argmax(p, -1), entropy

+     @torch.jit.export
+     def optimise(
+         self,
+         x: Tensor,
+         y: Optional[Tensor] = None,
+         sample_step: int = 100,
+         guidance_strength: float = 4.0,
+         token_mask: Optional[Tensor] = None,
+     ) -> Tuple[Tensor, Tensor]:
+         """
+         Optimise the template molecule (mol2mol). \n
+         This method is equivalent to sampling from a customised prior distribution.
+
+         :param x: categorical indices of template; shape: (n_b, n_t)
+         :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
+         :param sample_step: number of sampling steps
+         :param guidance_strength: strength of conditional generation. It is not used if y is null.
+         :param token_mask: token mask assigning unwanted token(s) with `True`;
+             shape: (1, 1, n_vocab)
+         :type x: torch.Tensor
+         :type y: torch.Tensor | None
+         :type sample_step: int
+         :type guidance_strength: float
+         :type token_mask: torch.Tensor | None
+         :return: sampled token indices; shape: (n_b, n_t) \n
+             entropy of the tokens; shape: (n_b)
+         :rtype: tuple
+         """
+         n_b = x.shape[0]
+         x_onehot = nn.functional.one_hot(x, self.K).float()
+         theta = nn.functional.softmax(x_onehot, -1)
+         if y is not None:
+             y = self.reshape_y(y)
+         for i in torch.linspace(1, sample_step, sample_step, device=x.device):
+             t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
+             p = self.discrete_output_distribution(theta, t, y, guidance_strength)
+             if token_mask is not None:
+                 p = p.masked_fill_(token_mask, 0.0)
+             alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
+             e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
+             mu = alpha * (self.K * e_k - 1)
+             sigma = (alpha * self.K).sqrt()
+             theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
+             theta = theta / theta.sum(-1, True)
+         t_final = torch.ones((n_b, 1, 1), device=x.device)
+         p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
+         entropy = -(p * p.log()).sum(-1).mean(-1)
+         if token_mask is not None:
+             p = p.masked_fill_(token_mask, 0.0)
+         return torch.argmax(p, -1), entropy
+
+     @torch.jit.export
+     def ode_optimise(
+         self,
+         x: Tensor,
+         y: Optional[Tensor] = None,
+         sample_step: int = 100,
+         guidance_strength: float = 4.0,
+         token_mask: Optional[Tensor] = None,
+         temperature: float = 0.5,
+     ) -> Tuple[Tensor, Tensor]:
+         """
+         ODE mol2mol.
+
+         :param x: categorical indices of template; shape: (n_b, n_t)
+         :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
+         :param sample_step: number of sampling steps
+         :param guidance_strength: strength of conditional generation. It is not used if y is null.
+         :param token_mask: token mask assigning unwanted token(s) with `True`;
+             shape: (1, 1, n_vocab)
+         :param temperature: sampling temperature
+         :type x: torch.Tensor
+         :type y: torch.Tensor | None
+         :type sample_step: int
+         :type guidance_strength: float
+         :type token_mask: torch.Tensor | None
+         :type temperature: float
+         :return: sampled token indices; shape: (n_b, n_t) \n
+             entropy of the tokens; shape: (n_b)
+         :rtype: tuple
+         """
+         n_b = x.shape[0]
+         z = nn.functional.one_hot(x, self.K).float()
+         if y is not None:
+             y = self.reshape_y(y)
+         for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
+             t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
+             theta = torch.softmax(z, -1)
+             beta = self.calc_beta(t + 1 / sample_step)
+             p = self.discrete_output_distribution(theta, t, y, guidance_strength)
+             if token_mask is not None:
+                 p = p.masked_fill_(token_mask, 0.0)
+             u = torch.randn_like(z)
+             z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
+         t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
+         theta = torch.softmax(z, -1)
+         p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
+         entropy = -(p * p.log()).sum(-1).mean(-1)
+         if token_mask is not None:
+             p = p.masked_fill_(token_mask, 0.0)
+         return torch.argmax(p, -1), entropy
+
      def inference(
          self, x: Tensor, mlp: MLP, embed_fn: Optional[Callable[[Tensor], Tensor]] = None
      ) -> Tensor:
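Both new methods are exported for TorchScript and mirror each other apart from the update rule. A hedged usage sketch; the ChemBFN(512) constructor call mirrors the new test file below, and a freshly initialised model runs but returns meaningless tokens, so trained weights should be loaded in practice. model.K is assumed here to hold the vocabulary size, as the self.K usages above suggest:

    import torch
    from bayesianflow_for_chem import ChemBFN

    model = ChemBFN(512).eval()
    x = torch.randint(0, model.K, (4, 32))  # (n_b, n_t) template token indices
    with torch.inference_mode():
        tokens, entropy = model.optimise(x, y=None, sample_step=100)
        # ODE variant at temperature 0.5, matching the "ODE:0.5" config value:
        tokens_ode, _ = model.ode_optimise(x, None, 100, 4.0, None, 0.5)
    print(tokens.shape, entropy.shape)  # (4, 32) and (4,)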
@@ -1250,6 +1352,71 @@ class EnsembleChemBFN(ChemBFN):
              x, y, sample_step, guidance_strength, token_mask, temperature
          )

+     @torch.inference_mode()
+     def optimise(
+         self,
+         x: Tensor,
+         conditions: Union[List[Tensor], Dict[str, Tensor]],
+         sample_step: int = 100,
+         guidance_strength: float = 4.0,
+         token_mask: Optional[Tensor] = None,
+     ) -> Tuple[Tensor, Tensor]:
+         """
+         Optimise the template molecule (mol2mol). \n
+         This method is equivalent to sampling from a customised prior distribution.
+
+         :param x: categorical indices of template; shape: (n_b, n_t)
+         :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
+         :param sample_step: number of sampling steps
+         :param guidance_strength: strength of conditional generation. It is not used if y is null.
+         :param token_mask: token mask assigning unwanted token(s) with `True`;
+             shape: (1, 1, n_vocab)
+         :type x: torch.Tensor
+         :type y: torch.Tensor | None
+         :type sample_step: int
+         :type guidance_strength: float
+         :type token_mask: torch.Tensor | None
+         :return: sampled token indices; shape: (n_b, n_t) \n
+             entropy of the tokens; shape: (n_b)
+         :rtype: tuple
+         """
+         y = self.construct_y(conditions)
+         return super().optimise(x, y, sample_step, guidance_strength, token_mask)
+
+     @torch.inference_mode()
+     def ode_optimise(
+         self,
+         x: Tensor,
+         conditions: Union[List[Tensor], Dict[str, Tensor]],
+         sample_step: int = 100,
+         guidance_strength: float = 4.0,
+         token_mask: Optional[Tensor] = None,
+         temperature: float = 0.5,
+     ) -> Tuple[Tensor, Tensor]:
+         """
+         ODE inpainting.
+
+         :param x: categorical indices of template; shape: (n_b, n_t)
+         :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
+         :param sample_step: number of sampling steps
+         :param guidance_strength: strength of conditional generation. It is not used if y is null.
+         :param token_mask: token mask; shape: (1, 1, n_vocab)
+         :param temperature: sampling temperature
+         :type x: torch.Tensor
+         :type conditions: list | dict
+         :type sample_step: int
+         :type guidance_strength: float
+         :type token_mask: torch.Tensor | None
+         :type temperature: float
+         :return: sampled token indices; shape: (n_b, n_t) \n
+             entropy of the tokens; shape: (n_b)
+         :rtype: tuple
+         """
+         y = self.construct_y(conditions)
+         return super().ode_optimise(
+             x, y, sample_step, guidance_strength, token_mask, temperature
+         )
+
      def quantise(
          self, quantise_method: Optional[Callable[[ChemBFN], None]] = None
      ) -> None:
{bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem/tool.py
@@ -219,7 +219,7 @@ def sample(
      sort: bool = False,
  ) -> List[str]:
      """
-     Sampling.
+     Sampling molecules.

      :param model: trained ChemBFN model
      :param batch_size: batch size
@@ -385,6 +385,96 @@ def inpaint(
      ]


+ @torch.no_grad()
+ def optimise(
+     model: Union[ChemBFN, EnsembleChemBFN],
+     x: Tensor,
+     sample_step: int = 100,
+     y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]] = None,
+     guidance_strength: float = 4.0,
+     device: Union[str, torch.device, None] = None,
+     vocab_keys: List[str] = VOCAB_KEYS,
+     separator: str = "",
+     method: str = "BFN",
+     allowed_tokens: Union[str, List[str]] = "all",
+     sort: bool = False,
+ ) -> List[str]:
+     """
+     Optimising template molecules (mol2mol).
+
+     :param model: trained ChemBFN model
+     :param x: categorical indices of template; shape: (n_b, n_t)
+     :param sample_step: number of sampling steps
+     :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f) \n
+         or a list/`dict` of conditions; shape: (n_b, n_c) * n_h
+
+     :param guidance_strength: strength of conditional generation. It is not used if y is null.
+     :param device: hardware accelerator
+     :param vocab_keys: a list of (ordered) vocabulary
+     :param separator: token separator; default is `""`
+     :param method: sampling method chosen from `"ODE:x"` or `"BFN"` where `x` is the value of sampling temperature; default is `"BFN"`
+     :param allowed_tokens: a list of allowed tokens; default is `"all"`
+     :param sort: whether to sort the samples according to entropy values; default is `False`
+     :type model: bayesianflow_for_chem.model.ChemBFN | bayesianflow_for_chem.model.EnsembleChemBFN
+     :type x: torch.Tensor
+     :type sample_step: int
+     :type y: torch.Tensor | list | dict | None
+     :type guidance_strength: float
+     :type device: str | torch.device | None
+     :type vocab_keys: list
+     :type separator: str
+     :type method: str
+     :type allowed_tokens: str | list
+     :type sort: bool
+     :return: a list of generated molecular strings
+     :rtype: list
+     """
+     assert method.split(":")[0].lower() in ("ode", "bfn")
+     if isinstance(model, EnsembleChemBFN):
+         assert y is not None, "conditioning is required while using an ensemble model."
+         assert isinstance(y, list) or isinstance(y, dict)
+     else:
+         assert isinstance(y, Tensor) or y is None
+     if device is None:
+         device = _find_device()
+     model.to(device).eval()
+     x = x.to(device)
+     if y is not None:
+         if isinstance(y, Tensor):
+             y = y.to(device)
+         elif isinstance(y, list):
+             y = [i.to(device) for i in y]
+         elif isinstance(y, dict):
+             y = {k: v.to(device) for k, v in y.items()}
+         else:
+             raise NotImplementedError
+     if isinstance(allowed_tokens, list):
+         token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
+         token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
+     else:
+         token_mask = None
+     if "ode" in method.lower():
+         tp = float(method.split(":")[-1])
+         assert tp > 0, "Sampling temperature should be higher than 0."
+         tokens, entropy = model.ode_optimise(
+             x, y, sample_step, guidance_strength, token_mask, tp
+         )
+     else:
+         tokens, entropy = model.optimise(
+             x, y, sample_step, guidance_strength, token_mask
+         )
+     if sort:
+         sorted_idx = entropy.argsort(stable=True)
+         tokens = tokens[sorted_idx]
+     return [
+         separator.join([vocab_keys[i] for i in j])
+         .split("<start>" + separator)[-1]
+         .split(separator + "<end>")[0]
+         .replace("<pad>", "")
+         for j in tokens
+     ]
+
+
  def quantise_model_(model: ChemBFN) -> None:
      """
      In-place dynamic quantisation of the trained model to `int8` data type. \n
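At the tool level, the new optimise helper wraps device placement, token masking, and string decoding, so mol2mol follows the same pattern as sample and inpaint. A hedged sketch of the unconditional path; random indices stand in for a properly tokenised template, and an untrained model will not produce chemically meaningful strings:

    import torch
    from bayesianflow_for_chem import ChemBFN
    from bayesianflow_for_chem.tool import optimise

    model = ChemBFN(512)  # load trained weights in practice
    x = torch.randint(0, model.K, (8, 32))  # tokenised templates, shape (n_b, n_t)
    mols = optimise(model, x, sample_step=100, y=None, method="BFN")
    # method="ODE:0.5" would select ode_optimise at temperature 0.5
    print(len(mols))  # 8

For an EnsembleChemBFN, y must instead be the list or dict of per-head condition tensors, as the assertions at the top of the function enforce.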
{bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: bayesianflow_for_chem
- Version: 2.0.5
+ Version: 2.1.0
  Summary: Bayesian flow network framework for Chemistry
  Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
  Author: Nianze A. Tao
{bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.1.0}/bayesianflow_for_chem.egg-info/SOURCES.txt
@@ -17,5 +17,6 @@ bayesianflow_for_chem.egg-info/dependency_links.txt
  bayesianflow_for_chem.egg-info/entry_points.txt
  bayesianflow_for_chem.egg-info/requires.txt
  bayesianflow_for_chem.egg-info/top_level.txt
+ test/test_jit_compatibility.py
  test/test_merge_lora.py
  test/test_molecular_embedding.py
bayesianflow_for_chem-2.1.0/test/test_jit_compatibility.py (new file)
@@ -0,0 +1,28 @@
+ # -*- coding: utf-8 -*-
+ # Author: Nianze A. Tao (Omozawa Sueno)
+ """
+ Model should be compatible with TorchScript.
+ """
+ import torch
+ from bayesianflow_for_chem import ChemBFN
+
+ model = ChemBFN(512)
+ model_method = [
+     "sample",
+     "ode_sample",
+     "inpaint",
+     "ode_inpaint",
+     "optimise",
+     "ode_optimise",
+ ]
+
+
+ @torch.inference_mode()
+ def test():
+     jit_model = torch.jit.script(model).eval()
+     assert isinstance(jit_model, torch.jit.ScriptModule)
+     for method in model_method:
+         assert hasattr(jit_model, method)
+     jit_model = torch.jit.freeze(jit_model, model_method)
+     for method in model_method:
+         assert hasattr(jit_model, method)
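Beyond asserting scriptability, a module frozen this way can be serialised and reloaded without the Python class definitions; a brief sketch, with an illustrative file name:

    import torch
    from bayesianflow_for_chem import ChemBFN

    jit_model = torch.jit.script(ChemBFN(512)).eval()
    jit_model = torch.jit.freeze(jit_model, ["sample", "ode_sample", "optimise", "ode_optimise"])
    jit_model.save("chembfn_frozen.pt")
    restored = torch.jit.load("chembfn_frozen.pt")
    assert hasattr(restored, "optimise")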