hjxdl 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/__init__.py +0 -0
- hdl/_version.py +16 -0
- hdl/args/__init__.py +0 -0
- hdl/args/loss_args.py +5 -0
- hdl/controllers/__init__.py +0 -0
- hdl/controllers/al/__init__.py +0 -0
- hdl/controllers/al/al.py +0 -0
- hdl/controllers/al/dispatcher.py +0 -0
- hdl/controllers/al/feedback.py +0 -0
- hdl/controllers/explain/__init__.py +0 -0
- hdl/controllers/explain/shapley.py +293 -0
- hdl/controllers/explain/subgraphx.py +865 -0
- hdl/controllers/train/__init__.py +0 -0
- hdl/controllers/train/rxn_train.py +219 -0
- hdl/controllers/train/train.py +50 -0
- hdl/controllers/train/train_ginet.py +316 -0
- hdl/controllers/train/trainer_base.py +155 -0
- hdl/controllers/train/trainer_iterative.py +389 -0
- hdl/data/__init__.py +0 -0
- hdl/data/dataset/__init__.py +0 -0
- hdl/data/dataset/base_dataset.py +98 -0
- hdl/data/dataset/fp/__init__.py +0 -0
- hdl/data/dataset/fp/fp_dataset.py +122 -0
- hdl/data/dataset/graph/__init__.py +0 -0
- hdl/data/dataset/graph/chiral.py +62 -0
- hdl/data/dataset/graph/gin.py +255 -0
- hdl/data/dataset/graph/molnet.py +362 -0
- hdl/data/dataset/loaders/__init__.py +0 -0
- hdl/data/dataset/loaders/chiral_graph.py +71 -0
- hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
- hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
- hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
- hdl/data/dataset/loaders/general.py +23 -0
- hdl/data/dataset/loaders/spliter.py +86 -0
- hdl/data/dataset/samplers/__init__.py +0 -0
- hdl/data/dataset/samplers/chiral.py +19 -0
- hdl/data/dataset/seq/__init__.py +0 -0
- hdl/data/dataset/seq/rxn_dataset.py +61 -0
- hdl/data/dataset/utils.py +31 -0
- hdl/data/to_mols.py +0 -0
- hdl/features/__init__.py +0 -0
- hdl/features/fp/__init__.py +0 -0
- hdl/features/fp/features_generators.py +235 -0
- hdl/features/graph/__init__.py +0 -0
- hdl/features/graph/featurization.py +297 -0
- hdl/features/utils/__init__.py +0 -0
- hdl/features/utils/utils.py +111 -0
- hdl/layers/__init__.py +0 -0
- hdl/layers/general/__init__.py +0 -0
- hdl/layers/general/gp.py +14 -0
- hdl/layers/general/linear.py +641 -0
- hdl/layers/graph/__init__.py +0 -0
- hdl/layers/graph/chiral_graph.py +230 -0
- hdl/layers/graph/gcn.py +16 -0
- hdl/layers/graph/gin.py +45 -0
- hdl/layers/graph/tetra.py +158 -0
- hdl/layers/graph/transformer.py +188 -0
- hdl/layers/sequential/__init__.py +0 -0
- hdl/metric_loss/__init__.py +0 -0
- hdl/metric_loss/loss.py +79 -0
- hdl/metric_loss/metric.py +178 -0
- hdl/metric_loss/multi_label.py +42 -0
- hdl/metric_loss/nt_xent.py +65 -0
- hdl/models/__init__.py +0 -0
- hdl/models/chiral_gnn.py +176 -0
- hdl/models/fast_transformer.py +234 -0
- hdl/models/ginet.py +189 -0
- hdl/models/linear.py +137 -0
- hdl/models/model_dict.py +18 -0
- hdl/models/norm_flows.py +33 -0
- hdl/models/optim_dict.py +16 -0
- hdl/models/rxn.py +63 -0
- hdl/models/utils.py +83 -0
- hdl/ops/__init__.py +0 -0
- hdl/ops/utils.py +42 -0
- hdl/optims/__init__.py +0 -0
- hdl/optims/nadam.py +86 -0
- hdl/utils/__init__.py +0 -0
- hdl/utils/chemical_tools/__init__.py +2 -0
- hdl/utils/chemical_tools/query_info.py +149 -0
- hdl/utils/chemical_tools/sdf.py +20 -0
- hdl/utils/database_tools/__init__.py +0 -0
- hdl/utils/database_tools/connect.py +28 -0
- hdl/utils/general/__init__.py +0 -0
- hdl/utils/general/glob.py +21 -0
- hdl/utils/schedulers/__init__.py +0 -0
- hdl/utils/schedulers/norm_lr.py +108 -0
- hjxdl-0.0.1.dist-info/METADATA +19 -0
- hjxdl-0.0.1.dist-info/RECORD +91 -0
- hjxdl-0.0.1.dist-info/WHEEL +5 -0
- hjxdl-0.0.1.dist-info/top_level.txt +1 -0
File without changes
|
@@ -0,0 +1,219 @@
|
|
1
|
+
from os import path as osp
|
2
|
+
import typing as t
|
3
|
+
|
4
|
+
import torch
|
5
|
+
from torch import nn
|
6
|
+
|
7
|
+
from hdl.models.rxn import build_rxn_mu
|
8
|
+
from hdl.models.utils import load_model, save_model
|
9
|
+
from hdl.data.dataset.seq.rxn_dataset import RXNCSVDataset
|
10
|
+
from hdl.data.dataset.loaders.rxn_loader import RXNLoader
|
11
|
+
from hdl.metric_loss.loss import mtmc_loss
|
12
|
+
from jupyfuncs.pbar import tnrange, tqdm
|
13
|
+
from jupyfuncs.glob import makedirs
|
14
|
+
# from hdl.optims.nadam import Nadam
|
15
|
+
from torch.optim import Adam
|
16
|
+
# from .trainer_base import TorchTrainer
|
17
|
+
|
18
|
+
|
19
|
+
def train_a_batch(
    model,
    batch_data,
    loss_func,
    optimizer,
    device,
    individual,
    **kwargs
):
    """Perform a single optimisation step on one batch.

    Args:
        model: Network called with the list of input tensors.
        batch_data: Indexable batch; ``batch_data[0]`` is an iterable of
            input tensors and ``batch_data[1]`` the target tensor, which is
            transposed (``.T``) before being handed to the loss
            (presumably (batch, tasks) -> (tasks, batch) -- TODO confirm).
        loss_func: Loss specification forwarded to ``mtmc_loss``.
        optimizer: Optimizer whose gradients are reset and which is stepped.
        device: Device the inputs and targets are moved to.
        individual: If truthy, ``mtmc_loss`` is expected to return a pair
            ``(total, per_task)``; otherwise a single total loss.
        **kwargs: Extra keyword arguments forwarded to ``mtmc_loss``.

    Returns:
        Tuple ``(total_loss, per_task_losses)``; ``per_task_losses`` is an
        empty list when ``individual`` is falsy.
    """
    optimizer.zero_grad()

    features = [tensor.to(device) for tensor in batch_data[0]]
    targets = batch_data[1].T.to(device)

    predictions = model(features)
    loss_output = mtmc_loss(
        predictions,
        targets,
        loss_func,
        individual=individual,
        **kwargs
    )

    # ``mtmc_loss`` returns (total, per-task) only when ``individual`` is set.
    if individual:
        total_loss, per_task_losses = loss_output[0], loss_output[1]
    else:
        total_loss, per_task_losses = loss_output, []

    total_loss.backward()
    optimizer.step()

    return total_loss, per_task_losses
|
52
|
+
|
53
|
+
|
54
|
+
def train_an_epoch(
    base_dir: str,
    model,
    data_loader,
    epoch_id: int,
    loss_func,
    optimizer,
    device,
    num_warm_epochs: int = 0,
    individual: bool = True,
    **kwargs
):
    """Train ``model`` for one pass over ``data_loader``, then checkpoint it.

    Each batch appends one tab-separated line to ``<base_dir>/loss.log``:
    the total loss followed by every individual loss, each field followed by
    a tab. After the pass the model is saved to
    ``<base_dir>/model.<epoch_id>.ckpt`` together with the last batch loss.

    Args:
        base_dir: Directory receiving ``loss.log`` and the checkpoint.
        model: Network to train; may be wrapped in ``nn.DataParallel``.
        data_loader: Iterable of batches accepted by ``train_a_batch``.
        epoch_id: Index of this epoch; used for the warm-up test and the
            checkpoint file name.
        loss_func: Loss specification forwarded through ``train_a_batch``.
        optimizer: Optimizer stepped once per batch.
        device: Device batches are moved to.
        num_warm_epochs: While ``epoch_id < num_warm_epochs`` the model's
            ``freeze_encoder`` attribute is set to True, otherwise False
            (presumably read by the model's forward -- confirm in hdl.models.rxn).
        individual: Whether per-task losses are requested and logged.
        **kwargs: Forwarded to ``train_a_batch``.

    Raises:
        ValueError: If ``data_loader`` yields no batches, so there is no loss
            to checkpoint.
    """
    # DataParallel / DDP keep the real network in ``.module``; an attribute
    # set on the wrapper is never seen by the wrapped model (or its
    # replicas), so toggle the flag on the underlying module.
    parallel_types = (
        torch.nn.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )
    core_model = model.module if isinstance(model, parallel_types) else model
    core_model.freeze_encoder = epoch_id < num_warm_epochs

    loss = None
    # Open the log once per epoch rather than once per batch.
    with open(osp.join(base_dir, 'loss.log'), 'a') as log_file:
        for batch in tqdm(data_loader):
            loss, individual_losses = train_a_batch(
                model=model,
                batch_data=batch,
                loss_func=loss_func,
                optimizer=optimizer,
                device=device,
                individual=individual,
                **kwargs
            )
            fields = [str(loss.item())]
            fields.extend(str(part) for part in individual_losses)
            log_file.write('\t'.join(fields) + '\t\n')

    if loss is None:
        raise ValueError(
            'data_loader yielded no batches; nothing to checkpoint'
        )

    ckpt_file = osp.join(base_dir, f'model.{epoch_id}.ckpt')
    save_model(
        model=model,
        save_dir=ckpt_file,
        epoch=epoch_id,
        optimizer=optimizer,
        loss=loss,
    )
|
103
|
+
|
104
|
+
|
105
|
+
def train_rxn(
    base_dir,
    model,
    num_epochs,
    loss_func,
    data_loader,
    optimizer,
    device,
    num_warm_epochs: int = 10,
    ckpt_file: str = None,
    individual: bool = True,
    **kwargs
):
    """Run ``num_epochs`` epochs of ``train_an_epoch``, optionally resuming.

    Args:
        base_dir: Output directory for logs and per-epoch checkpoints.
        model: Network to train.
        num_epochs: Number of epochs to run in this call.
        loss_func: Loss specification forwarded to each epoch.
        data_loader: Iterable of training batches.
        optimizer: Optimizer used for every step.
        device: Device batches are moved to.
        num_warm_epochs: Epoch ids below this value run with the encoder
            frozen (see ``train_an_epoch``).
        ckpt_file: Optional checkpoint to restore model/optimizer state from.
        individual: Whether per-task losses are requested and logged.
        **kwargs: Forwarded to ``train_an_epoch``.
    """
    start_epoch = 0
    if ckpt_file is not None:
        model, optimizer, last_epoch, _ = load_model(
            ckpt_file,
            model=model,
            optimizer=optimizer,
            train=True,
            device=device,
        )
        # ``train_an_epoch`` checkpoints with ``epoch=epoch_id`` of the epoch
        # just completed; assuming ``load_model`` returns that stored value,
        # continue with the next id instead of retraining (and overwriting)
        # the restored epoch.
        start_epoch = last_epoch + 1

    for offset in tnrange(num_epochs):
        train_an_epoch(
            base_dir=base_dir,
            model=model,
            data_loader=data_loader,
            epoch_id=start_epoch + offset,
            loss_func=loss_func,
            optimizer=optimizer,
            num_warm_epochs=num_warm_epochs,
            device=device,
            individual=individual,
            **kwargs
        )
|
144
|
+
|
145
|
+
|
146
|
+
def rxn_engine(
    base_dir: str,
    csv_file: str,
    splitter: str,
    smiles_col: str,
    hard: bool = False,
    num_epochs: int = 20,
    target_cols: t.Optional[t.List] = None,
    nums_classes: t.Optional[t.List] = None,
    loss_func: str = 'ce',
    num_warm_epochs: int = 10,
    batch_size: int = 128,
    hidden_size: int = 128,
    lr: float = 0.01,
    num_hidden_layers: int = 10,
    shuffle: bool = True,
    num_workers: int = 12,
    dim=-1,
    out_act='softmax',
    device_id: int = 0,
    individual: bool = True,
    ckpt_file: t.Optional[str] = None,
    **kwargs
):
    """Build a reaction model and train it on a CSV dataset.

    Creates ``base_dir`` (as an absolute path), builds the model with
    ``build_rxn_mu`` (wrapping it in ``nn.DataParallel`` when more than one
    CUDA device is visible), creates an Adam optimizer, loads the data with
    ``RXNCSVDataset``/``RXNLoader``, and calls ``train_rxn``.

    Args:
        base_dir: Output directory for logs and checkpoints.
        csv_file: Path of the training CSV.
        splitter: Forwarded to ``RXNCSVDataset``.
        smiles_col: Name of the SMILES column.
        target_cols: Target column names; defaults to an empty list.
        nums_classes: Per-target class counts for ``build_rxn_mu``;
            defaults to an empty list.
        ckpt_file: Optional checkpoint to resume from (previously always
            ``None``); forwarded to ``train_rxn``.
        Remaining arguments configure the model, optimizer, loader, and
        training loop and are forwarded unchanged; ``**kwargs`` go to
        ``train_rxn``.
    """
    # Fresh lists per call instead of shared mutable default arguments.
    target_cols = [] if target_cols is None else target_cols
    nums_classes = [] if nums_classes is None else nums_classes

    base_dir = osp.abspath(base_dir)
    makedirs(base_dir)
    model, device = build_rxn_mu(
        nums_classes=nums_classes,
        hard=hard,
        hidden_size=hidden_size,
        nums_hidden_layers=num_hidden_layers,
        dim=dim,
        out_act=out_act,
        device_id=device_id
    )
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.train()

    params = [{
        'params': model.parameters(),
        'lr': lr,
        'weight_decay': 0
    }]
    optimizer = Adam(params)

    dataset = RXNCSVDataset(
        csv_file=csv_file,
        splitter=splitter,
        smiles_col=smiles_col,
        target_cols=target_cols,
    )
    data_loader = RXNLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers
    )

    train_rxn(
        base_dir=base_dir,
        model=model,
        num_epochs=num_epochs,
        loss_func=loss_func,
        data_loader=data_loader,
        optimizer=optimizer,
        device=device,
        num_warm_epochs=num_warm_epochs,
        ckpt_file=ckpt_file,
        individual=individual,
        **kwargs
    )
|
218
|
+
|
219
|
+
|
@@ -0,0 +1,50 @@
|
|
1
|
+
import warnings

import torch

try:  # PyG >= 2.0 exposes the graph DataLoader here ...
    from torch_geometric.loader import DataLoader
except ImportError:  # ... older releases only provide it in ``data``.
    from torch_geometric.data import DataLoader

# Blanket filter: silences every warning for the rest of the session.
warnings.filterwarnings("ignore")

# NOTE(review): ``model`` and ``data`` are not defined in this script; they
# are expected to come from earlier notebook cells.

# MSE criterion; ``train`` takes its square root to report RMSE.
loss_fn = torch.nn.MSELoss()

# Use GPU for training when available. Move the model first so the optimizer
# is constructed over parameters that already live on the target device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0007)

# Wrap data in data loaders: first 80% for training, the rest for testing.
data_size = len(data)
NUM_GRAPHS_PER_BATCH = 64
train_size = int(data_size * 0.8)
loader = DataLoader(data[:train_size],
                    batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
test_loader = DataLoader(data[train_size:],
                         batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
|
20
|
+
|
21
|
+
def train(data):
    """Run one epoch over the module-level ``loader``.

    ``data`` is accepted for call-site compatibility but is not read; the
    batches come from the global ``loader``. Uses the globals ``model``,
    ``optimizer``, ``loss_fn`` and ``device``.

    Returns:
        ``(loss, embedding)`` of the final batch, where ``loss`` is the
        square root of ``loss_fn`` (RMSE when ``loss_fn`` is MSE).
    """
    def _step(graph_batch):
        # Move the batch to the training device (in place).
        graph_batch.to(device)
        optimizer.zero_grad()
        # The model takes node features, connectivity and batch assignment.
        prediction, emb = model(
            graph_batch.x.float(), graph_batch.edge_index, graph_batch.batch
        )
        batch_loss = torch.sqrt(loss_fn(prediction, graph_batch.y))
        batch_loss.backward()
        optimizer.step()
        return batch_loss, emb

    for graph_batch in loader:
        outcome = _step(graph_batch)
    return outcome
|
36
|
+
|
37
|
+
print("Starting training...")
losses = []
for epoch in range(2000):
    loss, h = train(data)
    # Store a plain float: keeping the tensor would keep it (and its autograd
    # history) alive for every epoch of the run.
    loss_value = loss.item()
    losses.append(loss_value)
    if epoch % 100 == 0:
        print(f"Epoch {epoch} | Train Loss {loss_value}")

# Visualize learning (training loss)
import seaborn as sns

loss_indices = list(range(len(losses)))
# seaborn >= 0.12 accepts only ``data`` positionally; pass x/y by keyword.
ax = sns.lineplot(x=loss_indices, y=losses)
ax
|
50
|
+
# As a result, we get something like this:
|
@@ -0,0 +1,316 @@
|
|
1
|
+
import typing as t
|
2
|
+
from os import path as osp
|
3
|
+
# from os import path as osp
|
4
|
+
from itertools import cycle
|
5
|
+
# import datetime
|
6
|
+
|
7
|
+
import torch
|
8
|
+
import numpy as np
|
9
|
+
import pandas as pd
|
10
|
+
|
11
|
+
|
12
|
+
# from jupyfuncs.glob import makedirs
|
13
|
+
from jupyfuncs.pbar import tnrange, tqdm
|
14
|
+
# from hdl.data.dataset.graph.gin import MoleculeDataset
|
15
|
+
from hdl.data.dataset.graph.gin import MoleculeDatasetWrapper
|
16
|
+
# from hdl.metric_loss.loss import get_lossfunc
|
17
|
+
# from hdl.models.utils import save_model
|
18
|
+
from .trainer_base import TorchTrainer
|
19
|
+
|
20
|
+
|
21
|
+
class GINTrainer(TorchTrainer):
    """Trainer for GIN-based regression models.

    Batches are indexable sequences: every item except the last is an
    iterable of graph objects that are moved to ``self.device``, and the last
    item is the target tensor. Targets are divided by ``TARGET_SCALE`` before
    the loss and metrics are computed, so predictions live in that scaled
    space.
    """

    # Divisor applied to raw targets (the engine's example target column is
    # "yield (%)", so presumably percent -> fraction).
    TARGET_SCALE = 100

    def __init__(
        self,
        base_dir,
        data_loader,
        test_loader,
        metrics: t.Optional[t.List[str]] = None,
        loss_func: str = 'mse',
        model=None,
        model_name=None,
        model_init_args=None,
        ckpt_file=None,
        model_ckpt=None,
        fix_emb=True,
        optimizer=None,
        optimizer_name=None,
        optimizer_kwargs=None,
        device=torch.device('cpu'),
    ) -> None:
        """Initialise via ``TorchTrainer`` and optionally freeze the encoders.

        Args:
            metrics: Metric names; defaults to ``['rsquared', 'rmse', 'mae']``.
            test_loader: Loader used for periodic validation; may be ``None``
                (e.g. prediction-only use), in which case validation is skipped.
            fix_emb: If True, set ``requires_grad = False`` on every parameter
                of each module in ``self.model.gins``.
            All other arguments are forwarded unchanged to ``TorchTrainer``.
        """
        if metrics is None:
            metrics = ['rsquared', 'rmse', 'mae']
        super().__init__(
            base_dir=base_dir,
            data_loader=data_loader,
            test_loader=test_loader,
            metrics=metrics,
            loss_func=loss_func,
            model=model,
            model_name=model_name,
            model_init_args=model_init_args,
            ckpt_file=ckpt_file,
            model_ckpt=model_ckpt,
            optimizer=optimizer,
            optimizer_name=optimizer_name,
            optimizer_kwargs=optimizer_kwargs,
            device=device,
        )
        if fix_emb:
            for gin in self.model.gins:
                for param in gin.parameters():
                    param.requires_grad = False

    def _inputs_to_device(self, data):
        """Move every graph in ``data[:-1]`` to ``self.device``."""
        for group in data[:-1]:
            for graph in group:
                # Return value ignored: relies on ``.to`` moving the object
                # in place (presumably PyG ``Data``/``Batch`` -- confirm).
                graph.to(self.device)

    def _scaled_targets(self, data):
        """Return the batch targets on ``self.device`` divided by the scale."""
        return data[-1].to(self.device) / self.TARGET_SCALE

    def train_a_batch(self, data):
        """Run one optimisation step on ``data`` and return the loss tensor."""
        self.optimizer.zero_grad()
        self._inputs_to_device(data)
        y = self._scaled_targets(data)

        y_pred = self.model(data).flatten()
        loss = self.loss_func(y_pred, y)

        loss.backward()
        self.optimizer.step()
        return loss

    def load_ckpt(self):
        """Delegate checkpoint loading to the model's own ``load_ckpt``."""
        self.model.load_ckpt()

    def _validate(self, test_data):
        """Log the validation loss and metrics for one test batch."""
        # Evaluate in eval mode without autograd, then restore the previous
        # mode, so dropout/normalisation layers behave as at inference time
        # and no graph is built for the validation forward pass.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                self._inputs_to_device(test_data)
                y = self._scaled_targets(test_data)
                y_pred = self.model(test_data).flatten()
                valid_loss = self.loss_func(y_pred, y)
        finally:
            self.model.train(was_training)

        self.logger.add_scalar(
            'valid_loss',
            valid_loss.item(),
            global_step=self.n_iter
        )

        y_pred = y_pred.cpu().detach().numpy()
        y = y.cpu().detach().numpy()
        for metric_name, metric in zip(self.metric_names, self.metrics):
            # NOTE(review): called as metric(y_pred, y) here but as
            # metric(y_true, y_pred) in the module-level ``predict``; confirm
            # which order the metric functions expect (matters for r^2).
            self.logger.add_scalar(
                metric_name,
                metric(y_pred, y),
                global_step=self.n_iter
            )

    def train_an_epoch(self):
        """Train one pass over ``data_loader``, validating every 10 steps.

        Validation batches are drawn cyclically from ``test_loader``. When
        ``test_loader`` is ``None`` or empty, training still proceeds and
        validation is skipped. The trainer is saved at the end of the epoch.
        """
        test_batches = (
            cycle(self.test_loader) if self.test_loader is not None else None
        )
        for data in self.data_loader:
            loss_value = self.train_a_batch(data).item()
            self.losses.append(loss_value)
            self.n_iter += 1
            self.logger.add_scalar(
                'train_loss',
                loss_value,
                global_step=self.n_iter
            )

            if test_batches is not None and self.n_iter % 10 == 0:
                # ``None`` when the test loader is empty.
                test_data = next(test_batches, None)
                if test_data is not None:
                    self._validate(test_data)

        self.save()
        self.epoch_id += 1

    def train(self, num_epochs):
        """Run ``train_an_epoch`` ``num_epochs`` times."""
        for _ in tnrange(num_epochs):
            self.train_an_epoch()

    def predict(self, data_loader):
        """Return the flattened predictions for every batch as one array."""
        result_list = []
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for data in tqdm(data_loader):
                self._inputs_to_device(data)
                y_pred = self.model(data).flatten()
                result_list.append(y_pred.cpu().detach().numpy())
        return np.hstack(result_list)
|
155
|
+
|
156
|
+
|
157
|
+
def engine(
    base_dir,
    data_path,
    test_data_path,
    batch_size=128,
    num_workers=64,
    model_name='GINMLPR',
    num_layers=5,
    emb_dim=300,
    feat_dim=512,
    out_dim=1,
    drop_ratio=0.0,
    pool='mean',
    ckpt_file=None,
    fix_emb: bool = False,
    device='cuda:1',
    num_epochs=300,
    optimizer_name='adam',
    lr=0.001,
    file_type: str = 'csv',
    smiles_col_names: t.Optional[t.List] = None,
    y_col_name: str = None,  # e.g. "yield (%)"
    loss_func: str = 'mse',
    metrics: t.Optional[t.List[str]] = None,
):
    """Train a GIN model on ``data_path``, validating on ``test_data_path``.

    Args:
        base_dir: Output directory passed to ``GINTrainer``.
        data_path: Training data file (read by ``MoleculeDatasetWrapper``).
        test_data_path: Validation data file.
        smiles_col_names: SMILES column names; their count becomes the
            model's ``num_smiles``. Defaults to an empty list.
        y_col_name: Target column name.
        metrics: Metric names; defaults to ``['rsquared', 'rmse', 'mae']``.
        ckpt_file: Passed both in the model init args and to ``GINTrainer``.
        Remaining arguments configure the model, loaders, optimizer, and
        training length and are forwarded unchanged.
    """
    # Fresh lists per call instead of shared mutable default arguments.
    if smiles_col_names is None:
        smiles_col_names = []
    if metrics is None:
        metrics = ['rsquared', 'rmse', 'mae']

    model_init_args = {
        "num_layer": num_layers,
        "emb_dim": emb_dim,
        "feat_dim": feat_dim,
        "out_dim": out_dim,
        "drop_ratio": drop_ratio,
        "pool": pool,
        "ckpt_file": ckpt_file,
        "num_smiles": len(smiles_col_names),
    }

    def _make_loader(path, shuffle):
        # valid_size=0 puts every row in the "test" split, so the test
        # loader serves the whole file.
        wrapper = MoleculeDatasetWrapper(
            batch_size=batch_size,
            num_workers=num_workers,
            valid_size=0,
            data_path=path,
            file_type=file_type,
            smi_col_names=smiles_col_names,
            y_col_name=y_col_name
        )
        return wrapper.get_test_loader(shuffle=shuffle)

    data_loader = _make_loader(data_path, shuffle=True)
    test_loader = _make_loader(test_data_path, shuffle=False)

    trainer = GINTrainer(
        base_dir=base_dir,
        model_name=model_name,
        model_init_args=model_init_args,
        optimizer_name=optimizer_name,
        ckpt_file=ckpt_file,
        fix_emb=fix_emb,
        optimizer_kwargs={"lr": lr},
        data_loader=data_loader,
        test_loader=test_loader,
        metrics=metrics,
        loss_func=loss_func,
        device=device
    )

    trainer.train(num_epochs=num_epochs)
|
234
|
+
|
235
|
+
|
236
|
+
def predict(
    base_dir,
    data_path,
    batch_size=128,
    num_workers=64,
    model_name='GINMLPR',
    num_layers=5,
    emb_dim=300,
    feat_dim=512,
    out_dim=1,
    drop_ratio=0.0,
    pool='mean',
    ckpt_file=None,
    model_ckpt=None,
    device='cuda:1',
    file_type: str = 'csv',
    smiles_col_names: t.Optional[t.List] = None,
    y_col_name: str = None,  # e.g. "yield (%)"
    metrics: t.Optional[t.List[str]] = None,
):
    """Predict with a trained GIN model and write the results to ``base_dir``.

    Writes ``<base_dir>/pred.csv`` (the input table plus a ``pred`` column).
    When ``y_col_name`` is given, also writes ``<base_dir>/metrics.csv`` with
    one row of metric values computed against that column divided by 100.

    Args:
        base_dir: Output directory (also passed to ``GINTrainer``).
        data_path: Input data file.
        model_ckpt: Trainer checkpoint loaded before predicting.
        smiles_col_names: SMILES column names; defaults to an empty list.
        metrics: Metric names; defaults to ``['rsquared', 'rmse', 'mae']``.
        Remaining arguments configure the model and loader.
    """
    # Fresh lists per call instead of shared mutable default arguments.
    if smiles_col_names is None:
        smiles_col_names = []
    if metrics is None:
        metrics = ['rsquared', 'rmse', 'mae']

    model_init_args = {
        "num_layer": num_layers,
        "emb_dim": emb_dim,
        "feat_dim": feat_dim,
        "out_dim": out_dim,
        "drop_ratio": drop_ratio,
        "pool": pool,
        "ckpt_file": ckpt_file,
        "num_smiles": len(smiles_col_names),
    }
    wrapper = MoleculeDatasetWrapper(
        batch_size=batch_size,
        num_workers=num_workers,
        valid_size=0,
        data_path=data_path,
        file_type=file_type,
        smi_col_names=smiles_col_names,
        y_col_name=y_col_name
    )
    # Unshuffled so predictions line up with the rows of ``data_path``.
    data_loader = wrapper.get_test_loader(shuffle=False)
    trainer = GINTrainer(
        base_dir=base_dir,
        model_name=model_name,
        model_init_args=model_init_args,
        model_ckpt=model_ckpt,
        data_loader=data_loader,
        test_loader=None,
        metrics=metrics,
        device=device
    )
    metric_list = trainer.metrics
    trainer.load(ckpt_file=model_ckpt)
    trainer.model.eval()
    # Inference only: avoid building autograd graphs for every batch.
    with torch.no_grad():
        results = trainer.predict(data_loader)

    # NOTE(review): always parsed as CSV, regardless of ``file_type``.
    df = pd.read_csv(data_path)
    df['pred'] = results
    df.to_csv(
        osp.join(base_dir, 'pred.csv'),
        index=False
    )

    if y_col_name is not None:
        metrics_df = pd.DataFrame()
        # Training divides targets by 100, so compare in that scale.
        y = df[y_col_name].array / 100
        for metric_name, metric in zip(metrics, metric_list):
            # NOTE(review): argument order here is (y_true, y_pred), whereas
            # the trainer's validation calls metric(y_pred, y); confirm which
            # order the metric functions expect.
            metrics_df[metric_name] = np.array([metric(y, results)])
        metrics_df.to_csv(
            osp.join(base_dir, 'metrics.csv'),
            index=False
        )
|
316
|
+
|