bayesianflow_for_chem-1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayesianflow_for_chem/__init__.py +11 -0
- bayesianflow_for_chem/data.py +250 -0
- bayesianflow_for_chem/model.py +927 -0
- bayesianflow_for_chem/scorer.py +134 -0
- bayesianflow_for_chem/tool.py +470 -0
- bayesianflow_for_chem/train.py +243 -0
- bayesianflow_for_chem/vocab.txt +246 -0
- bayesianflow_for_chem-1.2.0.dist-info/METADATA +162 -0
- bayesianflow_for_chem-1.2.0.dist-info/RECORD +11 -0
- bayesianflow_for_chem-1.2.0.dist-info/WHEEL +5 -0
- bayesianflow_for_chem-1.2.0.dist-info/top_level.txt +1 -0
bayesianflow_for_chem/train.py
@@ -0,0 +1,243 @@
+# -*- coding: utf-8 -*-
+# Author: Nianze A. Tao (Omozawa Sueno)
+"""
+Define ChemBFN and regressor models for training.
+"""
+from pathlib import Path
+from typing import Dict, Tuple, Union, Optional
+import torch
+import torch.optim as op
+import torch.nn.functional as F
+from loralib import lora_state_dict, mark_only_lora_as_trainable
+from torch import Tensor
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from lightning import LightningModule
+from .model import ChemBFN, MLP
+from .scorer import Scorer
+
+DEFAULT_MODEL_HPARAM = {"lr": 5e-5, "lr_warmup_step": 1000, "uncond_prob": 0.2}
+DEFAULT_REGRESSOR_HPARAM = {
+    "mode": "regression",
+    "lr_scheduler_factor": 0.8,
+    "lr_scheduler_patience": 20,
+    "lr_warmup_step": 1000,
+    "max_lr": 1e-4,
+    "freeze": False,
+}
+
+
+class Model(LightningModule):
+    def __init__(
+        self,
+        model: ChemBFN,
+        mlp: Optional[MLP] = None,
+        scorer: Optional[Scorer] = None,
+        hparam: Dict[str, Union[int, float]] = DEFAULT_MODEL_HPARAM,
+    ) -> None:
+        """
+        A `~lightning.LightningModule` wrapper of the Bayesian flow network for chemistry model.\n
+        This module is used in the training stage only. By calling `Model(...).export_model(YOUR_WORK_DIR)` after training,
+        the model(s) will be saved to `YOUR_WORK_DIR/model.pt` and (if an MLP is present) `YOUR_WORK_DIR/mlp.pt`.
+
+        :param model: `~bayesianflow_for_chem.model.ChemBFN` instance.
+        :param mlp: `~bayesianflow_for_chem.model.MLP` instance or `None`.
+        :param scorer: `~bayesianflow_for_chem.scorer.Scorer` instance or `None`.
+        :param hparam: a `dict` instance of hyperparameters. See `bayesianflow_for_chem.train.DEFAULT_MODEL_HPARAM`.
+        :type model: bayesianflow_for_chem.model.ChemBFN
+        :type mlp: bayesianflow_for_chem.model.MLP | None
+        :type scorer: bayesianflow_for_chem.scorer.Scorer | None
+        :type hparam: dict
+        """
+        super().__init__()
+        self.model = model
+        self.mlp = mlp
+        self.scorer = scorer
+        self.save_hyperparameters(hparam, ignore=["model", "mlp", "scorer"])
+        if model.lora_enabled:
+            mark_only_lora_as_trainable(self.model)
+        self.use_scorer = self.scorer is not None
+
+    def training_step(self, batch: Dict[str, Tensor]) -> Tensor:
+        x = batch["token"]
+        t = torch.rand((x.shape[0], 1, 1), device=x.device)
+        if "mask" in batch:
+            mask = batch["mask"]
+        else:
+            mask = None
+        if self.mlp is not None:
+            y = batch["value"]
+            y = self.mlp.forward(y)
+            if y.dim() == 2:
+                y = y[:, None, :]
+            y_mask = F.dropout(torch.ones_like(t), self.hparams.uncond_prob, True, True)
+            y_mask = (y_mask != 0).float()
+            loss, p = self.model.cts_loss(x, t, y * y_mask, mask, self.use_scorer)
+        else:
+            loss, p = self.model.cts_loss(x, t, None, mask, self.use_scorer)
+        self.log("continuous_time_loss", loss.item())
+        if self.use_scorer:
+            scorer_loss = self.scorer.calc_score_loss(p)
+            self.log(f"{self.scorer.name}_loss", scorer_loss.item())
+            loss += scorer_loss * self.scorer.eta
+        return loss
+
+    def configure_optimizers(self) -> Dict[str, op.AdamW]:
+        optimizer = op.AdamW(self.parameters(), lr=1e-8, weight_decay=0.01)
+        return {"optimizer": optimizer}
+
+    def optimizer_step(self, *args, **kwargs) -> None:
+        optimizer: op.AdamW = kwargs["optimizer"] if "optimizer" in kwargs else args[2]
+        # warm-up step
+        if self.trainer.global_step < self.hparams.lr_warmup_step:
+            lr_scale = int(self.trainer.global_step + 1) / self.hparams.lr_warmup_step
+            lr_scale = min(1.0, lr_scale)
+            for pg in optimizer.param_groups:
+                pg["lr"] = lr_scale * self.hparams.lr
+        super().optimizer_step(*args, **kwargs)
+        optimizer.zero_grad(set_to_none=True)
+
+    def export_model(self, workdir: Path) -> None:
+        """
+        Save the trained model.
+
+        :param workdir: the directory to save the model(s)
+        :type workdir: pathlib.Path
+        :return:
+        :rtype: None
+        """
+        if self.model.lora_enabled:
+            torch.save(
+                {
+                    "lora_nn": lora_state_dict(self.model),
+                    "lora_param": self.model.lora_param,
+                },
+                workdir / "lora.pt",
+            )
+        else:
+            torch.save(
+                {"nn": self.model.state_dict(), "hparam": self.model.hparam},
+                workdir / "model.pt",
+            )
+        if self.mlp is not None:
+            torch.save(
+                {"nn": self.mlp.state_dict(), "hparam": self.mlp.hparam},
+                workdir / "mlp.pt",
+            )
+
+
+class Regressor(LightningModule):
+    def __init__(
+        self,
+        model: ChemBFN,
+        mlp: MLP,
+        hparam: Dict[str, Union[str, int, float, bool]] = DEFAULT_REGRESSOR_HPARAM,
+    ) -> None:
+        """
+        A `~lightning.LightningModule` wrapper of the Bayesian flow network for chemistry regression model.\n
+        This module is used in the training stage only. By calling `Regressor(...).export_model(YOUR_WORK_DIR)` after training,
+        the models will be saved to `YOUR_WORK_DIR/readout.pt` and, unless the backbone is frozen, `YOUR_WORK_DIR/model_ft.pt` (or `YOUR_WORK_DIR/lora.pt` when LoRA is enabled).
+
+        :param model: `~bayesianflow_for_chem.model.ChemBFN` instance.
+        :param mlp: `~bayesianflow_for_chem.model.MLP` instance.
+        :param hparam: a `dict` instance of hyperparameters. See `bayesianflow_for_chem.train.DEFAULT_REGRESSOR_HPARAM`.
+        :type model: bayesianflow_for_chem.model.ChemBFN
+        :type mlp: bayesianflow_for_chem.model.MLP
+        :type hparam: dict
+        """
+        super().__init__()
+        self.model = model
+        self.mlp = mlp
+        self.model.requires_grad_(not hparam["freeze"])
+        self.save_hyperparameters(hparam, ignore=["model", "mlp"])
+        if model.lora_enabled:
+            mark_only_lora_as_trainable(self.model)
+        assert hparam["mode"] in ("regression", "classification")
+
+    @staticmethod
+    def _mask_label(label: Tensor) -> Tuple[Tensor, Tensor]:
+        # find the unlabelled position(s)
+        label_mask = (label != torch.inf).float()
+        # mask the unlabelled position(s)
+        masked_label = label.masked_fill(label == torch.inf, 0)
+        return label_mask, masked_label
+
+    def training_step(self, batch: Dict[str, Tensor]) -> Tensor:
+        x, y = batch["token"], batch["value"]
+        z = self.model.inference(x, self.mlp)
+        if self.hparams.mode == "classification":
+            n_b, n_y = y.shape
+            z = z.reshape(n_b * n_y, -1)
+            loss = F.cross_entropy(z, y.reshape(-1).to(torch.long))
+        else:
+            y_mask, y = self._mask_label(y)
+            loss = F.mse_loss(z * y_mask, y, reduction="mean")
+        self.log("train_loss", loss.item())
+        return loss
+
+    def validation_step(self, batch: Dict[str, Tensor]) -> None:
+        x, y = batch["token"], batch["value"]
+        z = self.model.inference(x, self.mlp)
+        if self.hparams.mode == "classification":
+            n_b, n_y = y.shape
+            z = z.reshape(n_b * n_y, -1)
+            val_loss = 1 - (torch.argmax(z, -1) == y.reshape(-1)).float().mean()
+        else:
+            y_mask, y = self._mask_label(y)
+            val_loss = (z * y_mask - y).abs().sum() / y_mask.sum()
+        self.log("val_loss", val_loss.item())
+
+    def configure_optimizers(self) -> Dict:
+        optimizer = op.AdamW(self.parameters(), lr=1e-7, weight_decay=0.01)
+        lr_scheduler_config = {
+            "scheduler": ReduceLROnPlateau(
+                optimizer,
+                "min",
+                factor=self.hparams.lr_scheduler_factor,
+                patience=self.hparams.lr_scheduler_patience,
+                min_lr=1e-6,
+            ),
+            "interval": "epoch",
+            "monitor": "val_loss",
+            "frequency": 1,
+            "strict": True,
+        }
+        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
+
+    def optimizer_step(self, *args, **kwargs) -> None:
+        optimizer: op.AdamW = kwargs["optimizer"] if "optimizer" in kwargs else args[2]
+        # warm-up step
+        if self.trainer.global_step < self.hparams.lr_warmup_step:
+            lr_scale = int(self.trainer.global_step + 1) / self.hparams.lr_warmup_step
+            lr_scale = min(1.0, lr_scale)
+            for pg in optimizer.param_groups:
+                pg["lr"] = lr_scale * self.hparams.max_lr
+        super().optimizer_step(*args, **kwargs)
+        optimizer.zero_grad(set_to_none=True)
+
+    def export_model(self, workdir: Path) -> None:
+        """
+        Save the trained model.
+
+        :param workdir: the directory to save the model
+        :type workdir: pathlib.Path
+        :return:
+        :rtype: None
+        """
+        torch.save(
+            {"nn": self.mlp.state_dict(), "hparam": self.mlp.hparam},
+            workdir / "readout.pt",
+        )
+        if not self.hparams.freeze:
+            if self.model.lora_enabled:
+                torch.save(
+                    {
+                        "lora_nn": lora_state_dict(self.model),
+                        "lora_param": self.model.lora_param,
+                    },
+                    workdir / "lora.pt",
+                )
+            else:
+                torch.save(
+                    {"nn": self.model.state_dict(), "hparam": self.model.hparam},
+                    workdir / "model_ft.pt",
+                )
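For orientation, the sketch below shows one way the `Model` wrapper above could be driven with a Lightning `Trainer` and then exported. It is a minimal, hypothetical example: the `ChemBFN` constructor arguments and the toy dataloader are assumptions (the real signatures live in `bayesianflow_for_chem/model.py` and `data.py`, which are not shown in full in this diff); only the batch keys (`"token"`, optional `"mask"`/`"value"`) and the `export_model` behaviour come from `train.py` itself.

```python
# Minimal, hypothetical training sketch for the generative Model wrapper.
from pathlib import Path

import torch
import lightning as L
from torch.utils.data import DataLoader

from bayesianflow_for_chem.model import ChemBFN
from bayesianflow_for_chem.train import Model, DEFAULT_MODEL_HPARAM

# Toy batches: Model.training_step only requires a "token" tensor per batch
# (plus an optional "mask", and a "value" tensor when an MLP head is attached).
tokens = torch.randint(3, 246, (128, 64))  # placeholder token ids
train_loader = DataLoader(tokens, batch_size=32, collate_fn=lambda b: {"token": torch.stack(b)})

chembfn = ChemBFN(246)  # assumed constructor: vocabulary size (246 tokens in vocab.txt)
wrapper = Model(chembfn, hparam=DEFAULT_MODEL_HPARAM)

trainer = L.Trainer(max_steps=10_000, log_every_n_steps=50)
trainer.fit(wrapper, train_dataloaders=train_loader)
wrapper.export_model(Path("."))  # writes ./model.pt (plus mlp.pt / lora.pt when applicable)
```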
bayesianflow_for_chem/vocab.txt
@@ -0,0 +1,246 @@
+<pad>
+<start>
+<end>
+>
+>>
+[
+]
+(
+)
+.
+=
+#
+-
++
+\
+/
+:
+~
+@
+?
+*
+$
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+H
+He
+Li
+Be
+B
+C
+N
+O
+F
+Ne
+Na
+Mg
+Al
+Si
+P
+S
+Cl
+Ar
+K
+Ca
+Sc
+Ti
+V
+Cr
+Mn
+Fe
+Co
+Ni
+Cu
+Zn
+Ga
+Ge
+As
+Se
+Br
+Kr
+Rb
+Sr
+Y
+Zr
+Nb
+Mo
+Tc
+Ru
+Rh
+Pd
+Ag
+Cd
+In
+Sn
+Sb
+Te
+I
+Xe
+Cs
+Ba
+Hf
+Ta
+W
+Re
+Os
+Ir
+Pt
+Au
+Hg
+Tl
+Pb
+Bi
+Po
+At
+Rn
+Fr
+Ra
+Rf
+Db
+Sg
+Bh
+Hs
+Mt
+Ds
+Rg
+Cn
+Nh
+Fl
+Mc
+Lv
+Ts
+Og
+La
+Ce
+Pr
+Nd
+Pm
+Sm
+Eu
+Gd
+Tb
+Dy
+Ho
+Er
+Tm
+Yb
+Lu
+Ac
+Th
+Pa
+U
+Np
+Pu
+Am
+Cm
+Bk
+Cf
+Es
+Fm
+Md
+No
+Lr
+b
+c
+n
+o
+s
+p
+%10
+%11
+%12
+%13
+%14
+%15
+%16
+%17
+%18
+%19
+%20
+%21
+%22
+%23
+%24
+%25
+%26
+%27
+%28
+%29
+%30
+%31
+%32
+%33
+%34
+%35
+%36
+%37
+%38
+%39
+%40
+%41
+%42
+%43
+%44
+%45
+%46
+%47
+%48
+%49
+%50
+%51
+%52
+%53
+%54
+%55
+%56
+%57
+%58
+%59
+%60
+%61
+%62
+%63
+%64
+%65
+%66
+%67
+%68
+%69
+%70
+%71
+%72
+%73
+%74
+%75
+%76
+%77
+%78
+%79
+%80
+%81
+%82
+%83
+%84
+%85
+%86
+%87
+%88
+%89
+%90
+%91
+%92
+%93
+%94
+%95
+%96
+%97
+%98
+%99
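A note on how these entries are used: the token ids appearing elsewhere in the package seem to be the 0-based line indices of this file, which matches the tokenisation example in the README further down (`c1ccsc1` → `[1, 151, 23, 151, 151, 154, 151, 23, 2]`, i.e. `<start> c 1 c c s c 1 <end>`). A small sketch, assuming that indexing:

```python
# Decode the README example token ids against vocab.txt (assuming 0-based line indices).
with open("bayesianflow_for_chem/vocab.txt", encoding="utf-8") as f:
    id2token = [line.rstrip("\n") for line in f]

print([id2token[i] for i in [1, 151, 23, 151, 151, 154, 151, 23, 2]])
# ['<start>', 'c', '1', 'c', 'c', 's', 'c', '1', '<end>']  -- thiophene, "c1ccsc1"
```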
bayesianflow_for_chem-1.2.0.dist-info/METADATA
@@ -0,0 +1,162 @@
+Metadata-Version: 2.2
+Name: bayesianflow_for_chem
+Version: 1.2.0
+Summary: Bayesian flow network framework for Chemistry
+Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
+Author: Nianze A. Tao
+Author-email: tao-nianze@hiroshima-u.ac.jp
+License: AGPL-3.0 licence
+Project-URL: Source, https://github.com/Augus1999/bayesian-flow-network-for-chemistry
+Keywords: Chemistry,CLM,ChemBFN
+Classifier: Development Status :: 5 - Production/Stable
+Classifier: Intended Audience :: Science/Research
+Classifier: License :: OSI Approved :: GNU Affero General Public License v3
+Classifier: Natural Language :: English
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Topic :: Scientific/Engineering :: Chemistry
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Requires-Python: >=3.9
+Description-Content-Type: text/markdown
+Requires-Dist: rdkit>=2023.9.6
+Requires-Dist: torch>=2.3.1
+Requires-Dist: numpy>=1.26.4
+Requires-Dist: loralib>=0.1.2
+Requires-Dist: lightning>=2.2.0
+Requires-Dist: scikit-learn>=1.5.0
+Requires-Dist: typing_extensions>=4.8.0
+Provides-Extra: geo2seq
+Requires-Dist: pynauty>=2.8.8.1; extra == "geo2seq"
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: keywords
+Dynamic: license
+Dynamic: project-url
+Dynamic: provides-extra
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
+
+# ChemBFN: Bayesian Flow Network for Chemistry
+
+[DOI: 10.1021/acs.jcim.4c01792](https://doi.org/10.1021/acs.jcim.4c01792)
+[arXiv: 2412.11439](https://arxiv.org/abs/2412.11439)
+
+This is the repository of the PyTorch implementation of the ChemBFN model.
+
+## Features
+
+ChemBFN provides state-of-the-art functionality for
+* SMILES or SELFIES-based *de novo* molecule generation
+* Protein sequence *de novo* generation
+* Classifier-free guidance conditional generation (single or multi-objective optimisation)
+* Context-guided conditional generation (inpainting)
+* Outstanding out-of-distribution chemical space sampling
+* Fast sampling via an ODE solver
+* Molecular property and activity prediction finetuning
+* Reaction yield prediction finetuning
+
+in an all-in-one-model style.
+
+## News
+
+* [21/01/2025] Our first paper has been accepted by [JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01792).
+* [17/12/2024] The second paper, on out-of-distribution generation, is available on [arxiv.org](https://arxiv.org/abs/2412.11439).
+* [31/07/2024] Paper is available on [arxiv.org](https://arxiv.org/abs/2407.20294).
+* [21/07/2024] Paper was submitted to arXiv.
+
+## Install
+
+```bash
+$ pip install -U bayesianflow_for_chem
+```
+
+## Usage
+
+You can find example scripts in the [📁example](./example) folder.
+
+## Pre-trained Model
+
+You can find pretrained models in the [releases](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/releases) or on our [🤗Hugging Face model page](https://huggingface.co/suenoomozawa/ChemBFN).
+
+## Dataset Handling
+
+We provide a Python class [`CSVData`](./bayesianflow_for_chem/data.py) to handle data stored in CSV (or a similar format) with headers that identify the entities. The following is a quickstart.
+
+1. Download your dataset file (e.g., ESOL from [MoleculeNet](https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv)) and split the file:
+```python
+>>> from bayesianflow_for_chem.tool import split_data
+
+>>> split_data("delaney-processed.csv", method="scaffold")
+```
+
+2. Load the split data:
+```python
+>>> from bayesianflow_for_chem.data import smiles2token, collate, CSVData
+
+>>> dataset = CSVData("delaney-processed_train.csv")
+>>> dataset[0]
+{'Compound ID': ['Thiophene'],
+ 'ESOL predicted log solubility in mols per litre': ['-2.2319999999999998'],
+ 'Minimum Degree': ['2'],
+ 'Molecular Weight': ['84.14299999999999'],
+ 'Number of H-Bond Donors': ['0'],
+ 'Number of Rings': ['1'],
+ 'Number of Rotatable Bonds': ['0'],
+ 'Polar Surface Area': ['0.0'],
+ 'measured log solubility in mols per litre': ['-1.33'],
+ 'smiles': ['c1ccsc1']}
+```
+
+3. Create a mapping function to tokenise the dataset and select values:
+```python
+>>> import torch
+
+>>> def encode(x):
+...     smiles = x["smiles"][0]
+...     value = [float(i) for i in x["measured log solubility in mols per litre"]]
+...     return {"token": smiles2token(smiles), "value": torch.tensor(value)}
+
+>>> dataset.map(encode)
+>>> dataset[0]
+{'token': tensor([  1, 151,  23, 151, 151, 154, 151,  23,   2]),
+ 'value': tensor([-1.3300])}
+```
+
+4. Wrap the dataset in <u>torch.utils.data.DataLoader</u>:
+```python
+>>> dataloader = torch.utils.data.DataLoader(dataset, 32, collate_fn=collate)
+```
+
+## Cite This Work
+
+```bibtex
+@misc{2024chembfn,
+      title={A Bayesian Flow Network Framework for Chemistry Tasks},
+      author={Nianze Tao and Minori Abe},
+      year={2024},
+      eprint={2407.20294},
+      archivePrefix={arXiv},
+      primaryClass={cs.LG},
+      url={https://arxiv.org/abs/2407.20294},
+}
+```
+Out-of-distribution generation:
+```bibtex
+@misc{2024chembfn_ood,
+      title={Bayesian Flow Is All You Need to Sample Out-of-Distribution Chemical Spaces},
+      author={Nianze Tao},
+      year={2024},
+      eprint={2412.11439},
+      archivePrefix={arXiv},
+      primaryClass={cs.LG},
+      url={https://arxiv.org/abs/2412.11439},
+}
+```
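Following on from step 4 of the README quickstart above, the sketch below shows how that dataloader might feed the `Regressor` wrapper from `train.py` for property-prediction finetuning. The `ChemBFN(...)` and `MLP(...)` constructor calls are placeholders (their real signatures are in `model.py`, not shown in full here), and one would normally start from one of the pretrained checkpoints mentioned above rather than a freshly initialised backbone.

```python
# Hypothetical finetuning sketch for the Regressor wrapper; constructor arguments are assumed.
from pathlib import Path

import lightning as L

from bayesianflow_for_chem.model import ChemBFN, MLP
from bayesianflow_for_chem.train import Regressor, DEFAULT_REGRESSOR_HPARAM

chembfn = ChemBFN(246)        # assumed constructor; normally a pretrained backbone
readout = MLP([512, 256, 1])  # assumed constructor: readout head with one regression target

regressor = Regressor(chembfn, readout, dict(DEFAULT_REGRESSOR_HPARAM))

# "dataloader" is the ESOL loader built in README step 4; a validation loader built the
# same way from the validation split is needed because "val_loss" drives the LR scheduler.
trainer = L.Trainer(max_epochs=100)
trainer.fit(regressor, train_dataloaders=dataloader, val_dataloaders=val_dataloader)
regressor.export_model(Path("."))  # writes ./readout.pt (+ ./model_ft.pt unless "freeze" is True)
```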
bayesianflow_for_chem-1.2.0.dist-info/RECORD
@@ -0,0 +1,11 @@
+bayesianflow_for_chem/__init__.py,sha256=9XG7bc3Fup_O4QjFVcUF-_ZkVioDlvmK1GHs2Caxo50,293
+bayesianflow_for_chem/data.py,sha256=9tpRba40lxwrB6aPSJMkxUglEVC3VEQC9wWxhDuz3Q8,7760
+bayesianflow_for_chem/model.py,sha256=xD4ZpJSFMCdfh5k4--8Lpa34nn_sUvgklK6d9CPjz10,35434
+bayesianflow_for_chem/scorer.py,sha256=mV1vX8aBGFra2BE7N8WHihVIo3dXmUdPQIGfSaiuNdk,4084
+bayesianflow_for_chem/tool.py,sha256=a8AnH3geBpNpF6SFiTqKBrWlwDcQzii1zSQiUoyiLgY,17009
+bayesianflow_for_chem/train.py,sha256=kj6icGqymUUYopDtpre1oE_wpvpeNilbpzgffBsd1tk,9589
+bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
+bayesianflow_for_chem-1.2.0.dist-info/METADATA,sha256=WXd5mvyBbHX5Ba8uF6jkys8SCfdurkMRkxg-AedcBlc,5743
+bayesianflow_for_chem-1.2.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+bayesianflow_for_chem-1.2.0.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
+bayesianflow_for_chem-1.2.0.dist-info/RECORD,,
bayesianflow_for_chem-1.2.0.dist-info/top_level.txt
@@ -0,0 +1 @@
+bayesianflow_for_chem