bayesianflow-for-chem 2.0.2.tar.gz → 2.0.4.tar.gz

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release: this version of bayesianflow-for-chem has been flagged as potentially problematic.
Files changed (23)
  1. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/PKG-INFO +7 -1
  2. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/README.md +5 -0
  3. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/__init__.py +4 -1
  4. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/cli.py +33 -18
  5. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/model.py +13 -0
  6. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/tool.py +4 -1
  7. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/train.py +1 -1
  8. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem.egg-info/PKG-INFO +7 -1
  9. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem.egg-info/SOURCES.txt +3 -1
  10. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem.egg-info/requires.txt +1 -0
  11. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/setup.py +1 -0
  12. bayesianflow_for_chem-2.0.4/test/test_merge_lora.py +40 -0
  13. bayesianflow_for_chem-2.0.4/test/test_molecular_embedding.py +67 -0
  14. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/LICENSE +0 -0
  15. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/data.py +0 -0
  16. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/scorer.py +0 -0
  17. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/spectra.py +0 -0
  18. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/vocab.txt +0 -0
  19. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem.egg-info/dependency_links.txt +0 -0
  20. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem.egg-info/entry_points.txt +0 -0
  21. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem.egg-info/top_level.txt +0 -0
  22. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/pyproject.toml +0 -0
  23. {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/setup.cfg +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: bayesianflow_for_chem
-Version: 2.0.2
+Version: 2.0.4
 Summary: Bayesian flow network framework for Chemistry
 Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
 Author: Nianze A. Tao
@@ -23,6 +23,7 @@ License-File: LICENSE
 Requires-Dist: rdkit>=2025.3.5
 Requires-Dist: torch>=2.8.0
 Requires-Dist: torchao>=0.12
+Requires-Dist: colorama>=0.4.6
 Requires-Dist: numpy>=2.3.2
 Requires-Dist: scipy>=1.16.1
 Requires-Dist: loralib>=0.1.2
@@ -49,6 +50,11 @@ Dynamic: summary
 
 This is the repository of the PyTorch implementation of ChemBFN model.
 
+### Build State
+
+[![PyPI](https://img.shields.io/pypi/v/bayesianflow-for-chem?color=ff69b4)](https://pypi.org/project/bayesianflow-for-chem/)
+![pytest](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/actions/workflows/pytest.yml/badge.svg)
+
 ## Features
 
 ChemBFN provides the state-of-the-art functionalities of
README.md
@@ -5,6 +5,11 @@
 
 This is the repository of the PyTorch implementation of ChemBFN model.
 
+### Build State
+
+[![PyPI](https://img.shields.io/pypi/v/bayesianflow-for-chem?color=ff69b4)](https://pypi.org/project/bayesianflow-for-chem/)
+![pytest](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/actions/workflows/pytest.yml/badge.svg)
+
 ## Features
 
 ChemBFN provides the state-of-the-art functionalities of
bayesianflow_for_chem/__init__.py
@@ -3,6 +3,7 @@
 """
 ChemBFN package.
 """
+import colorama
 from . import data, tool, train, scorer, spectra
 from .model import ChemBFN, MLP, EnsembleChemBFN
 from .cli import main_script
@@ -17,7 +18,7 @@ __all__ = [
     "MLP",
     "EnsembleChemBFN",
 ]
-__version__ = "2.0.2"
+__version__ = "2.0.4"
 __author__ = "Nianze A. Tao (Omozawa Sueno)"
 
 
@@ -28,4 +29,6 @@ def main() -> None:
     :return:
     :rtype: None
     """
+    colorama.just_fix_windows_console()
     main_script(__version__)
+    colorama.deinit()
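The pair of calls added above brackets the whole CLI run: `colorama.just_fix_windows_console()` makes the raw ANSI escape sequences emitted by the new coloured diagnostics (see the cli.py hunks below) render as colours on legacy Windows consoles, and `colorama.deinit()` restores the original stdout/stderr wrappers afterwards. A minimal sketch of the same bracketing pattern, assuming nothing about the package internals (the `run` function is illustrative only):

    import colorama

    def run() -> None:
        colorama.just_fix_windows_console()  # no-op on ANSI-capable terminals
        try:
            # prints "Critical" in red, matching the package's message style
            print("\033[0;31mCritical\033[0;0m something went wrong")
        finally:
            colorama.deinit()  # undo the stdout/stderr wrapping

    if __name__ == "__main__":
        run()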
bayesianflow_for_chem/cli.py
@@ -130,10 +130,11 @@ def parse_cli(version: str) -> argparse.Namespace:
     """
     parser = argparse.ArgumentParser(
         description="Madmol: a CLI molecular design tool for "
-        "de novo design and R-group replacement, "
+        "de novo design, R-group replacement, and sequence in-filling, "
         "based on generative route of ChemBFN method. "
         "Let's make some craziest molecules.",
-        epilog=f"Madmol {version}, developed in Hiroshima University",
+        epilog=f"Madmol {version}, developed in Hiroshima University by chemists for chemists. "
+        "Visit https://augus1999.github.io/bayesian-flow-network-for-chemistry/ for more details.",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     parser.add_argument(
@@ -180,14 +181,16 @@ def load_model_config(
         model_config = tomllib.load(f)
     if model_config["ChemBFN"]["num_vocab"] != "match vocabulary size":
         if not isinstance(model_config["ChemBFN"]["num_vocab"], int):
-            print(f"Critical in {config_file}: You must specify num_vocab.")
+            print(
+                f"\033[0;31mCritical\033[0;0m in {config_file}: You must specify num_vocab."
+            )
             flag_critical += 1
     if model_config["ChemBFN"]["base_model"]:
         model_file = model_config["ChemBFN"]["base_model"]
         for fn in model_file:
             if not os.path.exists(fn):
                 print(
-                    f"Critical in {config_file}: Base model file {fn} does not exist."
+                    f"\033[0;31mCritical\033[0;0m in {config_file}: Base model file {fn} does not exist."
                 )
                 flag_critical += 1
     if "MLP" in model_config:
@@ -195,14 +198,14 @@
         b = model_config["MLP"]["size"][-1]
         if a != b:
             print(
-                f"Critical in {config_file}: MLP hidden size {b} should match ChemBFN hidden size {a}."
+                f"\033[0;31mCritical\033[0;0m in {config_file}: MLP hidden size {b} should match ChemBFN hidden size {a}."
             )
             flag_critical += 1
         if model_config["MLP"]["base_model"]:
             model_file = model_config["MLP"]["base_model"]
             if not os.path.exists(model_file):
                 print(
-                    f"Critical in {config_file}: Base model file {fn} does not exist."
+                    f"\033[0;31mCritical\033[0;0m in {config_file}: Base model file {fn} does not exist."
                 )
                 flag_critical += 1
     return model_config, flag_critical, flag_warning
@@ -226,49 +229,61 @@ def load_runtime_config(
         config = tomllib.load(f)
     tokeniser_name = config["tokeniser"]["name"].lower()
     if not tokeniser_name in "smiles selfies safe fasta".split():
-        print(f"Critical in {config_file}: Unknown tokensier name: {tokeniser_name}.")
+        print(
+            f"\033[0;31mCritical\033[0;0m in {config_file}: Unknown tokensier name: {tokeniser_name}."
+        )
        flag_critical += 1
     if tokeniser_name == "selfies":
         vocab = config["tokeniser"]["vocab"]
         if vocab.lower() == "default":
-            print(f"Critical in {config_file}: You should specify a vocabulary file.")
+            print(
+                f"\033[0;31mCritical\033[0;0m in {config_file}: You should specify a vocabulary file."
+            )
             flag_critical += 1
         elif not os.path.exists(vocab):
-            print(f"Critical in {config_file}: Vocabulary file {vocab} does not exist.")
+            print(
+                f"\033[0;31mCritical\033[0;0m in {config_file}: Vocabulary file {vocab} does not exist."
+            )
             flag_critical += 1
     if "train" in config:
         dataset_file = config["train"]["dataset"]
         if not os.path.exists(dataset_file):
             print(
-                f"Critical in {config_file}: Dataset file {dataset_file} does not exist."
+                f"\033[0;31mCritical\033[0;0m in {config_file}: Dataset file {dataset_file} does not exist."
             )
             flag_critical += 1
         logger_name = config["train"]["logger_name"].lower()
         if not logger_name in "csv tensorboard wandb".split():
-            print(f"Critical in {config_file}: Unknown logger: {logger_name}.")
+            print(
+                f"\033[0;31mCritical\033[0;0m in {config_file}: Unknown logger: {logger_name}."
+            )
             flag_critical += 1
         if config["train"]["restart"]:
             ckpt_file = config["train"]["restart"]
             if not os.path.exists(ckpt_file):
                 print(
-                    f"Critical in {config_file}: Restart checkpoint file {ckpt_file} does not exist."
+                    f"\033[0;31mCritical\033[0;0m in {config_file}: Restart checkpoint file {ckpt_file} does not exist."
                 )
                 flag_critical += 1
     if "inference" in config:
         if not "train" in config:
             if not isinstance(config["inference"]["sequence_length"], int):
                 print(
-                    f"Critical in {config_file}: You must set an integer for sequence_length."
+                    f"\033[0;31mCritical\033[0;0m in {config_file}: You must set an integer for sequence_length."
                 )
                 flag_critical += 1
         if config["inference"]["guidance_objective"]:
             if not "guidance_objective_strength" in config["inference"]:
                 print(
-                    f"Critical in {config_file}: You need to add guidance_objective_strength."
+                    f"\033[0;31mCritical\033[0;0m in {config_file}: You need to add guidance_objective_strength."
                 )
                 flag_critical += 1
         result_dir = Path(config["inference"]["result_file"]).parent
-        assert os.path.exists(result_dir), f"directory {result_dir} does not exist."
+        if not os.path.exists(result_dir):
+            print(
+                f"\033[0;33mWarning\033[0;0m in {config_file}: Directory {result_dir} to save the result does not exist."
+            )
+            flag_warning += 1
     return config, flag_critical, flag_warning
 
 
@@ -306,7 +321,7 @@ def main_script(version: str) -> None:
         if runtime_config["train"]["enable_lora"]:
             if not model_config["ChemBFN"]["base_model"]:
                 print(
-                    f"Warning in {parser.model_config}: You should load a pretrained model first."
+                    f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained model first."
                 )
                 flag_warning += 1
         if not os.path.exists(runtime_config["train"]["checkpoint_save_path"]):
@@ -314,12 +329,12 @@
         else:
             if not model_config["ChemBFN"]["base_model"]:
                 print(
-                    f"Warning in {parser.model_config}: You should load a pretrained ChemBFN model."
+                    f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained ChemBFN model."
                 )
                 flag_warning += 1
             if not model_config["MLP"]["base_model"]:
                 print(
-                    f"Warning in {parser.model_config}: You should load a pretrained MLP."
+                    f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained MLP."
                 )
                 flag_warning += 1
     if "inference" in runtime_config:
bayesianflow_for_chem/model.py
@@ -1038,6 +1038,19 @@ class EnsembleChemBFN(ChemBFN):
         self.__delattr__("lora_enabled")
         self.__delattr__("lora_param")
         self.__delattr__("hparam")
+        # ------- merge LoRA parameters to reduce the latency -------
+        for _, v in self.models.items():
+            for module in v.modules():
+                if hasattr(module, "lora_A"):
+                    module.weight.data += (
+                        module.lora_B @ module.lora_A
+                    ) * module.scaling
+                    module.lora_enabled = False
+                    module.lora_A = None
+                    module.lora_B = None
+                    module.scaling = None
+                    module.lora_dropout = None
+            v.lora_enabled = False
 
     def construct_y(
         self, c: Union[List[Tensor], Dict[str, Tensor]]
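The block added above is the standard LoRA merge: an adapted layer computes `W x + scaling * (B @ A) x`, so folding `scaling * (B @ A)` into `W` once gives outputs identical up to floating-point round-off while removing the extra low-rank matmul from every forward pass. A toy demonstration on a bare `nn.Linear`, independent of the package (the `lora_A`/`lora_B`/`scaling` names mirror the loralib attributes used above):

    import torch
    from torch import nn

    layer = nn.Linear(512, 512, bias=False)
    r, scaling = 8, 2.0
    lora_A = torch.randn(r, 512) * 0.01    # (r, in_features)
    lora_B = torch.randn(512, r) * 0.01    # (out_features, r)

    x = torch.randn(4, 512)
    y_lora = x @ (layer.weight + scaling * lora_B @ lora_A).T  # adapter applied on the fly

    layer.weight.data += scaling * lora_B @ lora_A  # fold the low-rank update into the base weight
    y_merged = layer(x)

    assert torch.allclose(y_lora, y_merged, atol=1e-6)  # same outputs, lower latency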
bayesianflow_for_chem/tool.py
@@ -9,6 +9,7 @@ import warnings
 from pathlib import Path
 from typing import List, Dict, Tuple, Union, Optional
 import torch
+import colorama
 import numpy as np
 from torch import cuda, Tensor, softmax
 from torch.utils.data import DataLoader
@@ -141,6 +142,7 @@ def split_dataset(
     assert file.endswith(".csv")
     assert len(split_ratio) == 3
     assert method in ("random", "scaffold")
+    colorama.just_fix_windows_console()
     with open(file, "r") as f:
         data = list(csv.reader(f))
     header = data[0]
@@ -198,6 +200,7 @@ def split_dataset(
     with open(file.replace(".csv", "_val.csv"), "w", newline="") as fva:
         writer = csv.writer(fva)
         writer.writerows([header] + val_set)
+    colorama.deinit()
 
 
 @torch.no_grad()
@@ -467,7 +470,7 @@ class GeometryConverter:
         spin: float = 0.0,
     ) -> Tuple[List[str], np.ndarray]:
         """
-        Guess the 3D geometry from SMILES string via MMFF conformer search.
+        Guess the 3D geometry from SMILES string via conformer search.
 
         :param smiles: a valid SMILES string
         :param num_conformers: number of initial conformers
bayesianflow_for_chem/train.py
@@ -134,7 +134,7 @@ class Regressor(LightningModule):
         hparam: Dict[str, Union[str, int, float, bool]] = DEFAULT_REGRESSOR_HPARAM,
     ) -> None:
         """
-        A `~lightning.LightningModule` wrapper of bayesian flow network for chemistry regression model.\n
+        A `~lightning.LightningModule` wrapper of bayesian flow network for chemistry regression or classification model.\n
         This module is used in training stage only. By calling `Regressor(...).export_model(YOUR_WORK_DIR)` after training,
         the models will be saved to `YOUR_WORK_DIR/model_ft.pt` (if LoRA is enabled then `YOUR_WORK_DIR/lora.pt`)
         and `YOUR_WORK_DIR/readout.pt`.
bayesianflow_for_chem.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: bayesianflow_for_chem
-Version: 2.0.2
+Version: 2.0.4
 Summary: Bayesian flow network framework for Chemistry
 Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
 Author: Nianze A. Tao
@@ -23,6 +23,7 @@ License-File: LICENSE
 Requires-Dist: rdkit>=2025.3.5
 Requires-Dist: torch>=2.8.0
 Requires-Dist: torchao>=0.12
+Requires-Dist: colorama>=0.4.6
 Requires-Dist: numpy>=2.3.2
 Requires-Dist: scipy>=1.16.1
 Requires-Dist: loralib>=0.1.2
@@ -49,6 +50,11 @@ Dynamic: summary
 
 This is the repository of the PyTorch implementation of ChemBFN model.
 
+### Build State
+
+[![PyPI](https://img.shields.io/pypi/v/bayesianflow-for-chem?color=ff69b4)](https://pypi.org/project/bayesianflow-for-chem/)
+![pytest](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/actions/workflows/pytest.yml/badge.svg)
+
 ## Features
 
 ChemBFN provides the state-of-the-art functionalities of
bayesianflow_for_chem.egg-info/SOURCES.txt
@@ -16,4 +16,6 @@ bayesianflow_for_chem.egg-info/SOURCES.txt
 bayesianflow_for_chem.egg-info/dependency_links.txt
 bayesianflow_for_chem.egg-info/entry_points.txt
 bayesianflow_for_chem.egg-info/requires.txt
-bayesianflow_for_chem.egg-info/top_level.txt
+bayesianflow_for_chem.egg-info/top_level.txt
+test/test_merge_lora.py
+test/test_molecular_embedding.py
bayesianflow_for_chem.egg-info/requires.txt
@@ -1,6 +1,7 @@
 rdkit>=2025.3.5
 torch>=2.8.0
 torchao>=0.12
+colorama>=0.4.6
 numpy>=2.3.2
 scipy>=1.16.1
 loralib>=0.1.2
setup.py
@@ -55,6 +55,7 @@ setup(
         "rdkit>=2025.3.5",
         "torch>=2.8.0",
         "torchao>=0.12",
+        "colorama>=0.4.6",
         "numpy>=2.3.2",
         "scipy>=1.16.1",
         "loralib>=0.1.2",
test/test_merge_lora.py
@@ -0,0 +1,40 @@
+# -*- coding: utf-8 -*-
+# Author: Nianze A. Tao (Omozawa Sueno)
+"""
+Model output should be almost identical before and after emerging LoRA parameters into base model.
+"""
+import torch
+from bayesianflow_for_chem import ChemBFN, MLP
+from bayesianflow_for_chem.tool import merge_lora_
+from bayesianflow_for_chem.data import VOCAB_COUNT, smiles2token, collate
+
+torch.manual_seed(8964)
+
+model = ChemBFN(VOCAB_COUNT)
+model.enable_lora(r=8)
+model.eval()
+mlp = MLP([512, 256, 3], dropout=0.7)
+mlp.eval()
+for module in model.modules():
+    if hasattr(module, "lora_B"):
+        torch.nn.init.kaiming_uniform_(module.lora_B, a=5**0.5)
+
+x = collate(
+    [{"token": smiles2token("c1ccccc1O")}, {"token": smiles2token("[NH4+]CCCCCC[O-]")}]
+)["token"]
+
+
+@torch.inference_mode()
+def test():
+    model.semi_autoregressive = False
+    y1 = model.inference(x, mlp)
+    model.semi_autoregressive = True
+    y2 = model.inference(x, mlp)
+    merge_lora_(model)
+    model.semi_autoregressive = False
+    y3 = model.inference(x, mlp)
+    model.semi_autoregressive = True
+    y4 = model.inference(x, mlp)
+    assert not model.lora_enabled
+    assert (y1 - y3).abs().mean() < 1e-6
+    assert (y2 - y4).abs().mean() < 1e-6
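This new test exercises the `merge_lora_` helper from `bayesianflow_for_chem.tool`: it records model outputs with the low-rank adapters active, in both semi-autoregressive and non-autoregressive modes, merges the adapters into the base weights, and checks that the outputs still agree to within 1e-6. Assuming a standard checkout, it should run under the pytest workflow advertised by the new README badge, e.g. `pytest test/test_merge_lora.py`.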
test/test_molecular_embedding.py
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+# Author: Nianze A. Tao (Omozawa Sueno)
+"""
+Molecular embedding vectors should not be affected by <pad> tokens.
+"""
+from functools import partial
+import torch
+from bayesianflow_for_chem import ChemBFN, MLP
+from bayesianflow_for_chem.data import VOCAB_COUNT, smiles2token
+
+torch.manual_seed(8964)
+
+model = ChemBFN(VOCAB_COUNT)
+model.eval()
+mlp1 = MLP([512, 256, 3], dropout=0.7)
+mlp1.eval()
+mlp2 = MLP([1024, 512, 3], dropout=0.7)
+mlp2.eval()
+
+x = smiles2token("c1ccccc1O.[NH4+]CCCCCC[O-]")
+x1 = x[None, ...]
+x2 = torch.nn.functional.pad(x1, (0, 7, 0, 0))
+
+
+def embed_fn(z, sar_flag, mask, x):
+    mb0 = z[x == 2].view(z.shape[0], -1) if sar_flag else z[::, 0]
+    mb1 = (z * mask[..., None]).sum(1) / (mask != 0).float().sum(1, True)
+    return torch.cat([mb0, mb1], -1)
+
+
+@torch.inference_mode()
+def test():
+    model.semi_autoregressive = False
+    y1 = model.inference(x1, mlp1)
+    y2 = model.inference(x2, mlp1)
+    assert (y1 != y2).sum() == 0
+    model.semi_autoregressive = True
+    y1 = model.inference(x1, mlp1)
+    y2 = model.inference(x2, mlp1)
+    assert (y1 != y2).sum() == 0
+    # ------- customised embedding extraction -------
+    mask1 = torch.tensor([[0] + [0.7] * 9 + [0] + [0.3] * 16 + [0]])
+    mask2 = torch.tensor([[0] + [0.7] * 9 + [0] + [0.3] * 16 + [0] * 8])
+    model.semi_autoregressive = False
+    y1 = model.inference(
+        x1,
+        mlp2,
+        partial(embed_fn, sar_flag=model.semi_autoregressive, mask=mask1, x=x1),
+    )
+    y2 = model.inference(
+        x2,
+        mlp2,
+        partial(embed_fn, sar_flag=model.semi_autoregressive, mask=mask2, x=x2),
+    )
+    assert (y1 != y2).sum() == 0
+    model.semi_autoregressive = True
+    y1 = model.inference(
+        x1,
+        mlp2,
+        partial(embed_fn, sar_flag=model.semi_autoregressive, mask=mask1, x=x1),
+    )
+    y2 = model.inference(
+        x2,
+        mlp2,
+        partial(embed_fn, sar_flag=model.semi_autoregressive, mask=mask2, x=x2),
+    )
+    assert (y1 != y2).sum() == 0
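The custom `embed_fn` above is what makes this padding-invariance test meaningful: the per-token mask gives `<pad>` positions zero weight, so the mask-weighted pooling is unchanged when extra pad tokens are appended. A toy sketch of that invariance, independent of ChemBFN:

    import torch

    z = torch.randn(1, 5, 4)                          # (batch, tokens, hidden)
    z_pad = torch.nn.functional.pad(z, (0, 0, 0, 3))  # append 3 pad positions
    mask = torch.tensor([[0.0, 0.7, 0.7, 0.3, 0.0]])
    mask_pad = torch.nn.functional.pad(mask, (0, 3))  # pads get weight 0

    def pool(h, m):
        # mask-weighted sum over tokens, normalised by the count of non-zero weights
        return (h * m[..., None]).sum(1) / (m != 0).float().sum(1, True)

    assert torch.allclose(pool(z, mask), pool(z_pad, mask_pad))  # pad rows contribute nothing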