hjxdl-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/__init__.py +0 -0
- hdl/_version.py +16 -0
- hdl/args/__init__.py +0 -0
- hdl/args/loss_args.py +5 -0
- hdl/controllers/__init__.py +0 -0
- hdl/controllers/al/__init__.py +0 -0
- hdl/controllers/al/al.py +0 -0
- hdl/controllers/al/dispatcher.py +0 -0
- hdl/controllers/al/feedback.py +0 -0
- hdl/controllers/explain/__init__.py +0 -0
- hdl/controllers/explain/shapley.py +293 -0
- hdl/controllers/explain/subgraphx.py +865 -0
- hdl/controllers/train/__init__.py +0 -0
- hdl/controllers/train/rxn_train.py +219 -0
- hdl/controllers/train/train.py +50 -0
- hdl/controllers/train/train_ginet.py +316 -0
- hdl/controllers/train/trainer_base.py +155 -0
- hdl/controllers/train/trainer_iterative.py +389 -0
- hdl/data/__init__.py +0 -0
- hdl/data/dataset/__init__.py +0 -0
- hdl/data/dataset/base_dataset.py +98 -0
- hdl/data/dataset/fp/__init__.py +0 -0
- hdl/data/dataset/fp/fp_dataset.py +122 -0
- hdl/data/dataset/graph/__init__.py +0 -0
- hdl/data/dataset/graph/chiral.py +62 -0
- hdl/data/dataset/graph/gin.py +255 -0
- hdl/data/dataset/graph/molnet.py +362 -0
- hdl/data/dataset/loaders/__init__.py +0 -0
- hdl/data/dataset/loaders/chiral_graph.py +71 -0
- hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
- hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
- hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
- hdl/data/dataset/loaders/general.py +23 -0
- hdl/data/dataset/loaders/spliter.py +86 -0
- hdl/data/dataset/samplers/__init__.py +0 -0
- hdl/data/dataset/samplers/chiral.py +19 -0
- hdl/data/dataset/seq/__init__.py +0 -0
- hdl/data/dataset/seq/rxn_dataset.py +61 -0
- hdl/data/dataset/utils.py +31 -0
- hdl/data/to_mols.py +0 -0
- hdl/features/__init__.py +0 -0
- hdl/features/fp/__init__.py +0 -0
- hdl/features/fp/features_generators.py +235 -0
- hdl/features/graph/__init__.py +0 -0
- hdl/features/graph/featurization.py +297 -0
- hdl/features/utils/__init__.py +0 -0
- hdl/features/utils/utils.py +111 -0
- hdl/layers/__init__.py +0 -0
- hdl/layers/general/__init__.py +0 -0
- hdl/layers/general/gp.py +14 -0
- hdl/layers/general/linear.py +641 -0
- hdl/layers/graph/__init__.py +0 -0
- hdl/layers/graph/chiral_graph.py +230 -0
- hdl/layers/graph/gcn.py +16 -0
- hdl/layers/graph/gin.py +45 -0
- hdl/layers/graph/tetra.py +158 -0
- hdl/layers/graph/transformer.py +188 -0
- hdl/layers/sequential/__init__.py +0 -0
- hdl/metric_loss/__init__.py +0 -0
- hdl/metric_loss/loss.py +79 -0
- hdl/metric_loss/metric.py +178 -0
- hdl/metric_loss/multi_label.py +42 -0
- hdl/metric_loss/nt_xent.py +65 -0
- hdl/models/__init__.py +0 -0
- hdl/models/chiral_gnn.py +176 -0
- hdl/models/fast_transformer.py +234 -0
- hdl/models/ginet.py +189 -0
- hdl/models/linear.py +137 -0
- hdl/models/model_dict.py +18 -0
- hdl/models/norm_flows.py +33 -0
- hdl/models/optim_dict.py +16 -0
- hdl/models/rxn.py +63 -0
- hdl/models/utils.py +83 -0
- hdl/ops/__init__.py +0 -0
- hdl/ops/utils.py +42 -0
- hdl/optims/__init__.py +0 -0
- hdl/optims/nadam.py +86 -0
- hdl/utils/__init__.py +0 -0
- hdl/utils/chemical_tools/__init__.py +2 -0
- hdl/utils/chemical_tools/query_info.py +149 -0
- hdl/utils/chemical_tools/sdf.py +20 -0
- hdl/utils/database_tools/__init__.py +0 -0
- hdl/utils/database_tools/connect.py +28 -0
- hdl/utils/general/__init__.py +0 -0
- hdl/utils/general/glob.py +21 -0
- hdl/utils/schedulers/__init__.py +0 -0
- hdl/utils/schedulers/norm_lr.py +108 -0
- hjxdl-0.0.1.dist-info/METADATA +19 -0
- hjxdl-0.0.1.dist-info/RECORD +91 -0
- hjxdl-0.0.1.dist-info/WHEEL +5 -0
- hjxdl-0.0.1.dist-info/top_level.txt +1 -0
hdl/controllers/train/rxn_train.py
@@ -0,0 +1,219 @@
from os import path as osp
import typing as t

import torch
from torch import nn

from hdl.models.rxn import build_rxn_mu
from hdl.models.utils import load_model, save_model
from hdl.data.dataset.seq.rxn_dataset import RXNCSVDataset
from hdl.data.dataset.loaders.rxn_loader import RXNLoader
from hdl.metric_loss.loss import mtmc_loss
from jupyfuncs.pbar import tnrange, tqdm
from jupyfuncs.glob import makedirs
# from hdl.optims.nadam import Nadam
from torch.optim import Adam
# from .trainer_base import TorchTrainer


def train_a_batch(
    model,
    batch_data,
    loss_func,
    optimizer,
    device,
    individual,
    **kwargs
):
    optimizer.zero_grad()

    X = [x.to(device) for x in batch_data[0]]
    y = batch_data[1].T.to(device)

    y_preds = model(X)
    loss = mtmc_loss(
        y_preds,
        y,
        loss_func,
        individual=individual, **kwargs
    )

    if not individual:
        final_loss = loss
        individual_losses = []
    else:
        final_loss = loss[0]
        individual_losses = loss[1]

    final_loss.backward()
    optimizer.step()

    return final_loss, individual_losses


def train_an_epoch(
    base_dir: str,
    model,
    data_loader,
    epoch_id: int,
    loss_func,
    optimizer,
    device,
    num_warm_epochs: int = 0,
    individual: bool = True,
    **kwargs
):
    if epoch_id < num_warm_epochs:
        model.freeze_encoder = True
    else:
        model.freeze_encoder = False

    for batch in tqdm(data_loader):
        loss, individual_losses = train_a_batch(
            model=model,
            batch_data=batch,
            loss_func=loss_func,
            optimizer=optimizer,
            device=device,
            individual=individual,
            **kwargs
        )
        with open(
            osp.join(base_dir, 'loss.log'),
            'a'
        ) as f:
            f.write(str(loss.item()))
            f.write('\t')
            for individual_loss in individual_losses:
                f.write(str(individual_loss))
                f.write('\t')
            f.write('\n')

    ckpt_file = osp.join(
        base_dir,
        f'model.{epoch_id}.ckpt'
    )
    save_model(
        model=model,
        save_dir=ckpt_file,
        epoch=epoch_id,
        optimizer=optimizer,
        loss=loss,
    )


def train_rxn(
    base_dir,
    model,
    num_epochs,
    loss_func,
    data_loader,
    optimizer,
    device,
    num_warm_epochs: int = 10,
    ckpt_file: str = None,
    individual: bool = True,
    **kwargs
):

    epoch = 0
    if ckpt_file is not None:

        model, optimizer, epoch, _ = load_model(
            ckpt_file,
            model=model,
            optimizer=optimizer,
            train=True,
            device=device,
        )

    for epoch_id in tnrange(num_epochs):

        train_an_epoch(
            base_dir=base_dir,
            model=model,
            data_loader=data_loader,
            epoch_id=epoch + epoch_id,
            loss_func=loss_func,
            optimizer=optimizer,
            num_warm_epochs=num_warm_epochs,
            device=device,
            individual=individual,
            **kwargs
        )


def rxn_engine(
    base_dir: str,
    csv_file: str,
    splitter: str,
    smiles_col: str,
    hard: bool = False,
    num_epochs: int = 20,
    target_cols: t.List = [],
    nums_classes: t.List = [],
    loss_func: str = 'ce',
    num_warm_epochs: int = 10,
    batch_size: int = 128,
    hidden_size: int = 128,
    lr: float = 0.01,
    num_hidden_layers: int = 10,
    shuffle: bool = True,
    num_workers: int = 12,
    dim=-1,
    out_act='softmax',
    device_id: int = 0,
    individual: bool = True,
    **kwargs
):

    base_dir = osp.abspath(base_dir)
    makedirs(base_dir)
    model, device = build_rxn_mu(
        nums_classes=nums_classes,
        hard=hard,
        hidden_size=hidden_size,
        nums_hidden_layers=num_hidden_layers,
        dim=dim,
        out_act=out_act,
        device_id=device_id
    )
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.train()

    params = [{
        'params': model.parameters(),
        'lr': lr,
        'weight_decay': 0
    }]
    optimizer = Adam(params)

    dataset = RXNCSVDataset(
        csv_file=csv_file,
        splitter=splitter,
        smiles_col=smiles_col,
        target_cols=target_cols,
    )
    data_loader = RXNLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers
    )

    train_rxn(
        base_dir=base_dir,
        model=model,
        num_epochs=num_epochs,
        loss_func=loss_func,
        data_loader=data_loader,
        optimizer=optimizer,
        device=device,
        num_warm_epochs=num_warm_epochs,
        ckpt_file=None,
        individual=individual,
        **kwargs
    )
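For the first num_warm_epochs epochs, train_an_epoch sets model.freeze_encoder = True, so only the task heads are updated before the encoder is unfrozen. A minimal sketch of how rxn_engine might be driven, assuming a CSV with one reaction-SMILES column and two categorical targets; every path, column name, and class count below is illustrative, not taken from the package:

from hdl.controllers.train.rxn_train import rxn_engine

# All values here are hypothetical; only the keyword names come from the
# rxn_engine signature above.
rxn_engine(
    base_dir='./runs/rxn_demo',           # loss.log and model.<epoch>.ckpt land here
    csv_file='reactions.csv',             # hypothetical input table
    splitter='>',                         # forwarded to RXNCSVDataset; semantics defined there
    smiles_col='rxn_smiles',              # hypothetical reaction-SMILES column
    target_cols=['catalyst', 'solvent'],  # hypothetical label columns
    nums_classes=[12, 8],                 # one class count per target column
    loss_func='ce',                       # resolved inside mtmc_loss
    num_warm_epochs=10,                   # encoder stays frozen for these epochs
    num_epochs=20,
    device_id=0,
)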
hdl/controllers/train/train.py
@@ -0,0 +1,50 @@
import torch  # added: torch is used throughout but was never imported
from torch_geometric.data import DataLoader
import warnings
warnings.filterwarnings("ignore")

# Root mean squared error
loss_fn = torch.nn.MSELoss()
# NOTE: `model` (and `data` below) is used but never defined in this module;
# both must be created before this script can run.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0007)

# Use GPU for training
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Wrap data in a data loader (80/20 train/test split)
data_size = len(data)
NUM_GRAPHS_PER_BATCH = 64
loader = DataLoader(data[:int(data_size * 0.8)],
                    batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
test_loader = DataLoader(data[int(data_size * 0.8):],
                         batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)


def train(data):
    # Enumerate over the data
    for batch in loader:
        # Use GPU
        batch.to(device)
        # Reset gradients
        optimizer.zero_grad()
        # Passing the node features and the connection info
        pred, embedding = model(batch.x.float(), batch.edge_index, batch.batch)
        # Calculating the loss and gradients (RMSE = sqrt of MSE)
        loss = torch.sqrt(loss_fn(pred, batch.y))
        loss.backward()
        # Update using the gradients
        optimizer.step()
    return loss, embedding


print("Starting training...")
losses = []
for epoch in range(2000):
    loss, h = train(data)
    losses.append(loss)
    if epoch % 100 == 0:
        print(f"Epoch {epoch} | Train Loss {loss}")

# Visualize learning (training loss)
import seaborn as sns
losses_float = [float(loss.cpu().detach().numpy()) for loss in losses]
loss_indices = [i for i, l in enumerate(losses_float)]
ax = sns.lineplot(x=loss_indices, y=losses_float)  # keyword args; positional x/y is removed in seaborn >= 0.12
ax
# As a result we get something like this: (stray prose in the published file,
# kept as a comment so the module stays importable)
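As published, this module is not self-contained: model and data are never defined, torch was never imported, and the final line was prose, so the file could not even be imported; the version above fixes the import and comments out the stray line, but model and data still have to come from elsewhere. A purely hypothetical preamble that would make the script runnable end to end (the GCN architecture and the ESOL dataset are assumptions for illustration, not anything shipped in hjxdl):

# Hypothetical preamble for train.py; none of this is in the package.
import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.datasets import MoleculeNet

data = MoleculeNet(root='.', name='ESOL')  # example regression dataset

class GCN(torch.nn.Module):
    def __init__(self, hidden=64):
        super().__init__()
        self.conv1 = GCNConv(data.num_features, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.out = Linear(hidden, 1)

    def forward(self, x, edge_index, batch):
        h = self.conv1(x, edge_index).relu()
        h = self.conv2(h, edge_index).relu()
        h = global_mean_pool(h, batch)  # pool node states to a graph embedding
        return self.out(h), h           # (prediction, embedding), matching train()

model = GCN()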
hdl/controllers/train/train_ginet.py
@@ -0,0 +1,316 @@
import typing as t
from os import path as osp
# from os import path as osp
from itertools import cycle
# import datetime

import torch
import numpy as np
import pandas as pd


# from jupyfuncs.glob import makedirs
from jupyfuncs.pbar import tnrange, tqdm
# from hdl.data.dataset.graph.gin import MoleculeDataset
from hdl.data.dataset.graph.gin import MoleculeDatasetWrapper
# from hdl.metric_loss.loss import get_lossfunc
# from hdl.models.utils import save_model
from .trainer_base import TorchTrainer


class GINTrainer(TorchTrainer):
    def __init__(
        self,
        base_dir,
        data_loader,
        test_loader,
        metrics: t.List[str] = ['rsquared', 'rmse', 'mae'],
        loss_func: str = 'mse',
        model=None,
        model_name=None,
        model_init_args=None,
        ckpt_file=None,
        model_ckpt=None,
        fix_emb=True,
        optimizer=None,
        optimizer_name=None,
        optimizer_kwargs=None,
        device=torch.device('cpu'),
        # logger=None
    ) -> None:
        super().__init__(
            base_dir=base_dir,
            data_loader=data_loader,
            test_loader=test_loader,
            metrics=metrics,
            loss_func=loss_func,
            model=model,
            model_name=model_name,
            model_init_args=model_init_args,
            ckpt_file=ckpt_file,
            model_ckpt=model_ckpt,
            optimizer=optimizer,
            optimizer_name=optimizer_name,
            optimizer_kwargs=optimizer_kwargs,
            device=device,
        )
        # self.loss_func = get_lossfunc(self.loss_func)
        # self.metrics = [get_metric(metric) for metric in metrics]
        if fix_emb:
            for gin in self.model.gins:
                for param in gin.parameters():
                    param.requires_grad = False

    def train_a_batch(self, data):
        self.optimizer.zero_grad()
        for i in data[:-1]:
            for j in i:
                j.to(self.device)
        y = data[-1].to(self.device)
        y = y / 100

        y_pred = self.model(data).flatten()

        loss = self.loss_func(y_pred, y)

        loss.backward()
        self.optimizer.step()

        return loss

    def load_ckpt(self):
        self.model.load_ckpt()

    def train_an_epoch(
        self,
    ):
        for i, (data, test_data) in enumerate(
            zip(
                self.data_loader,
                cycle(self.test_loader)
            )
        ):
            loss = self.train_a_batch(data)
            self.losses.append(loss.item())
            self.n_iter += 1
            self.logger.add_scalar(
                'train_loss',
                loss.item(),
                global_step=self.n_iter
            )

            if self.n_iter % 10 == 0:
                for i in test_data[:-1]:
                    for j in i:
                        j.to(self.device)
                y = test_data[-1].to(self.device)
                y = y / 100

                y_pred = self.model(test_data).flatten()
                valid_loss = self.loss_func(y_pred, y)

                y_pred = y_pred.cpu().detach().numpy()
                y = y.cpu().detach().numpy()

                self.logger.add_scalar(
                    'valid_loss',
                    valid_loss.item(),
                    global_step=self.n_iter
                )

                for metric_name, metric in zip(
                    self.metric_names,
                    self.metrics
                ):
                    self.logger.add_scalar(
                        metric_name,
                        metric(y_pred, y),
                        global_step=self.n_iter
                    )

        self.save()
        self.epoch_id += 1

    def train(self, num_epochs):
        # dir_name = datetime.now().strftime('%b%d_%H-%M-%S')
        # makedirs(osp.join(self.base_dir, dir_name))

        for _ in tnrange(num_epochs):
            self.train_an_epoch()

    def predict(self, data_loader):
        result_list = []
        for data in tqdm(data_loader):
            for i in data[:-1]:
                for j in i:
                    j.to(self.device)
            # print(data[0][0].x.device)
            # for param in self.model.parameters():
            #     print(param.device)
            #     break
            y_pred = self.model(data).flatten()
            result_list.append(y_pred.cpu().detach().numpy())
        results = np.hstack(result_list)
        return results


def engine(
    base_dir,
    data_path,
    test_data_path,
    batch_size=128,
    num_workers=64,
    model_name='GINMLPR',
    num_layers=5,
    emb_dim=300,
    feat_dim=512,
    out_dim=1,
    drop_ratio=0.0,
    pool='mean',
    ckpt_file=None,
    fix_emb: bool = False,
    device='cuda:1',
    num_epochs=300,
    optimizer_name='adam',
    lr=0.001,
    file_type: str = 'csv',
    smiles_col_names: t.List = [],
    y_col_name: str = None,  # "yield (%)",
    loss_func: str = 'mse',
    metrics: t.List[str] = ['rsquared', 'rmse', 'mae'],
):
    model_init_args = {
        "num_layer": num_layers,
        "emb_dim": emb_dim,
        "feat_dim": feat_dim,
        "out_dim": out_dim,
        "drop_ratio": drop_ratio,
        "pool": pool,
        "ckpt_file": ckpt_file,
        "num_smiles": len(smiles_col_names),
    }
    wrapper = MoleculeDatasetWrapper(
        batch_size=batch_size,
        num_workers=num_workers,
        valid_size=0,
        data_path=data_path,
        file_type=file_type,
        smi_col_names=smiles_col_names,
        y_col_name=y_col_name
    )
    test_wrapper = MoleculeDatasetWrapper(
        batch_size=batch_size,
        num_workers=num_workers,
        valid_size=0,
        data_path=test_data_path,
        file_type=file_type,
        smi_col_names=smiles_col_names,
        y_col_name=y_col_name
    )

    data_loader = wrapper.get_test_loader(
        shuffle=True
    )
    test_loader = test_wrapper.get_test_loader(
        shuffle=False
    )

    trainer = GINTrainer(
        base_dir=base_dir,
        model_name=model_name,
        model_init_args=model_init_args,
        optimizer_name=optimizer_name,
        ckpt_file=ckpt_file,
        fix_emb=fix_emb,
        optimizer_kwargs={"lr": lr},
        data_loader=data_loader,
        test_loader=test_loader,
        metrics=metrics,
        loss_func=loss_func,
        device=device
    )

    trainer.train(num_epochs=num_epochs)


def predict(
    base_dir,
    data_path,
    batch_size=128,
    num_workers=64,
    model_name='GINMLPR',
    num_layers=5,
    emb_dim=300,
    feat_dim=512,
    out_dim=1,
    drop_ratio=0.0,
    pool='mean',
    ckpt_file=None,
    model_ckpt=None,
    device='cuda:1',
    file_type: str = 'csv',
    smiles_col_names: t.List = [],
    y_col_name: str = None,  # "yield (%)",
    metrics: t.List[str] = ['rsquared', 'rmse', 'mae'],
):
    model_init_args = {
        "num_layer": num_layers,
        "emb_dim": emb_dim,
        "feat_dim": feat_dim,
        "out_dim": out_dim,
        "drop_ratio": drop_ratio,
        "pool": pool,
        "ckpt_file": ckpt_file,
        "num_smiles": len(smiles_col_names),
    }
    wrapper = MoleculeDatasetWrapper(
        batch_size=batch_size,
        num_workers=num_workers,
        valid_size=0,
        data_path=data_path,
        file_type=file_type,
        smi_col_names=smiles_col_names,
        y_col_name=y_col_name
    )
    data_loader = wrapper.get_test_loader(
        shuffle=False
    )
    trainer = GINTrainer(
        base_dir=base_dir,
        model_name=model_name,
        model_init_args=model_init_args,
        model_ckpt=model_ckpt,
        data_loader=data_loader,
        test_loader=None,
        metrics=metrics,
        device=device
    )
    metric_list = trainer.metrics
    trainer.load(ckpt_file=model_ckpt)
    trainer.model.eval()
    results = trainer.predict(data_loader)

    df = pd.read_csv(data_path)
    df['pred'] = results
    df.to_csv(
        osp.join(base_dir, 'pred.csv'),
        index=False
    )

    if y_col_name is not None:
        metrics_df = pd.DataFrame()
        y = df[y_col_name].array / 100
        for metric_name, metric in zip(
            metrics,
            metric_list
        ):
            metrics_df[metric_name] = np.array([metric(
                y, results
            )])
        metrics_df.to_csv(
            osp.join(
                base_dir, 'metrics.csv'
            ),
            index=False
        )
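A sketch of how engine and predict might be called for a yield-regression run; the file paths and SMILES column names below are invented for illustration, while 'yield (%)' echoes the commented-out default in the signatures above (note the trainer divides targets by 100):

from hdl.controllers.train.train_ginet import engine, predict

# Training: hypothetical CSVs with two SMILES columns and a yield column.
engine(
    base_dir='./runs/gin_yield',
    data_path='train.csv',                     # hypothetical training CSV
    test_data_path='test.csv',                 # hypothetical held-out CSV
    smiles_col_names=['reactant', 'product'],  # hypothetical SMILES columns
    y_col_name='yield (%)',                    # scaled to [0, 1] inside the trainer
    ckpt_file='pretrained_gin.pth',            # hypothetical pretrained GIN weights
    fix_emb=True,                              # freeze the pretrained GIN layers
    num_epochs=300,
)

# Inference: writes pred.csv (and metrics.csv when y_col_name is given)
# into base_dir.
predict(
    base_dir='./runs/gin_yield',
    data_path='test.csv',
    smiles_col_names=['reactant', 'product'],
    y_col_name='yield (%)',
    model_ckpt='./runs/gin_yield/model.ckpt',  # hypothetical checkpoint path
)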