hjxdl 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/__init__.py +0 -0
- hdl/_version.py +16 -0
- hdl/args/__init__.py +0 -0
- hdl/args/loss_args.py +5 -0
- hdl/controllers/__init__.py +0 -0
- hdl/controllers/al/__init__.py +0 -0
- hdl/controllers/al/al.py +0 -0
- hdl/controllers/al/dispatcher.py +0 -0
- hdl/controllers/al/feedback.py +0 -0
- hdl/controllers/explain/__init__.py +0 -0
- hdl/controllers/explain/shapley.py +293 -0
- hdl/controllers/explain/subgraphx.py +865 -0
- hdl/controllers/train/__init__.py +0 -0
- hdl/controllers/train/rxn_train.py +219 -0
- hdl/controllers/train/train.py +50 -0
- hdl/controllers/train/train_ginet.py +316 -0
- hdl/controllers/train/trainer_base.py +155 -0
- hdl/controllers/train/trainer_iterative.py +389 -0
- hdl/data/__init__.py +0 -0
- hdl/data/dataset/__init__.py +0 -0
- hdl/data/dataset/base_dataset.py +98 -0
- hdl/data/dataset/fp/__init__.py +0 -0
- hdl/data/dataset/fp/fp_dataset.py +122 -0
- hdl/data/dataset/graph/__init__.py +0 -0
- hdl/data/dataset/graph/chiral.py +62 -0
- hdl/data/dataset/graph/gin.py +255 -0
- hdl/data/dataset/graph/molnet.py +362 -0
- hdl/data/dataset/loaders/__init__.py +0 -0
- hdl/data/dataset/loaders/chiral_graph.py +71 -0
- hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
- hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
- hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
- hdl/data/dataset/loaders/general.py +23 -0
- hdl/data/dataset/loaders/spliter.py +86 -0
- hdl/data/dataset/samplers/__init__.py +0 -0
- hdl/data/dataset/samplers/chiral.py +19 -0
- hdl/data/dataset/seq/__init__.py +0 -0
- hdl/data/dataset/seq/rxn_dataset.py +61 -0
- hdl/data/dataset/utils.py +31 -0
- hdl/data/to_mols.py +0 -0
- hdl/features/__init__.py +0 -0
- hdl/features/fp/__init__.py +0 -0
- hdl/features/fp/features_generators.py +235 -0
- hdl/features/graph/__init__.py +0 -0
- hdl/features/graph/featurization.py +297 -0
- hdl/features/utils/__init__.py +0 -0
- hdl/features/utils/utils.py +111 -0
- hdl/layers/__init__.py +0 -0
- hdl/layers/general/__init__.py +0 -0
- hdl/layers/general/gp.py +14 -0
- hdl/layers/general/linear.py +641 -0
- hdl/layers/graph/__init__.py +0 -0
- hdl/layers/graph/chiral_graph.py +230 -0
- hdl/layers/graph/gcn.py +16 -0
- hdl/layers/graph/gin.py +45 -0
- hdl/layers/graph/tetra.py +158 -0
- hdl/layers/graph/transformer.py +188 -0
- hdl/layers/sequential/__init__.py +0 -0
- hdl/metric_loss/__init__.py +0 -0
- hdl/metric_loss/loss.py +79 -0
- hdl/metric_loss/metric.py +178 -0
- hdl/metric_loss/multi_label.py +42 -0
- hdl/metric_loss/nt_xent.py +65 -0
- hdl/models/__init__.py +0 -0
- hdl/models/chiral_gnn.py +176 -0
- hdl/models/fast_transformer.py +234 -0
- hdl/models/ginet.py +189 -0
- hdl/models/linear.py +137 -0
- hdl/models/model_dict.py +18 -0
- hdl/models/norm_flows.py +33 -0
- hdl/models/optim_dict.py +16 -0
- hdl/models/rxn.py +63 -0
- hdl/models/utils.py +83 -0
- hdl/ops/__init__.py +0 -0
- hdl/ops/utils.py +42 -0
- hdl/optims/__init__.py +0 -0
- hdl/optims/nadam.py +86 -0
- hdl/utils/__init__.py +0 -0
- hdl/utils/chemical_tools/__init__.py +2 -0
- hdl/utils/chemical_tools/query_info.py +149 -0
- hdl/utils/chemical_tools/sdf.py +20 -0
- hdl/utils/database_tools/__init__.py +0 -0
- hdl/utils/database_tools/connect.py +28 -0
- hdl/utils/general/__init__.py +0 -0
- hdl/utils/general/glob.py +21 -0
- hdl/utils/schedulers/__init__.py +0 -0
- hdl/utils/schedulers/norm_lr.py +108 -0
- hjxdl-0.0.1.dist-info/METADATA +19 -0
- hjxdl-0.0.1.dist-info/RECORD +91 -0
- hjxdl-0.0.1.dist-info/WHEEL +5 -0
- hjxdl-0.0.1.dist-info/top_level.txt +1 -0
hdl/controllers/train/trainer_base.py ADDED
@@ -0,0 +1,155 @@
+from abc import abstractmethod, ABC
+from os import path as osp
+
+import torch
+from torch.utils.tensorboard import SummaryWriter
+
+from jupyfuncs.glob import makedirs
+
+from hdl.models.optim_dict import OPTIM_DICT
+from hdl.models.model_dict import MODEL_DICT
+from hdl.models.utils import save_model, load_model
+from hdl.metric_loss.loss import get_lossfunc
+from hdl.metric_loss.metric import get_metric
+
+
+class TorchTrainer(ABC):
+    def __init__(
+        self,
+        base_dir,
+        data_loader,
+        test_loader,
+        metrics,
+        loss_func,
+        model=None,
+        model_name=None,
+        model_init_args=None,
+        ckpt_file=None,
+        model_ckpt=None,
+        optimizer=None,
+        optimizer_name=None,
+        optimizer_kwargs=None,
+        device=torch.device('cpu'),
+    ) -> None:
+        super().__init__()
+        self.base_dir = base_dir
+        self.data_loader = data_loader
+        self.test_loader = test_loader
+
+        if metrics is not None:
+            self.metric_names = metrics
+            self.metrics = [get_metric(metric) for metric in metrics]
+        if loss_func is not None:
+            self.loss_name = loss_func
+            self.loss_func = get_lossfunc(loss_func)
+        if isinstance(device, str):
+            self.device = torch.device(device)
+        else:
+            self.device = device
+        self.logger = SummaryWriter(log_dir=self.base_dir)
+
+        self.losses = []
+
+        if model is not None:
+            self.model = model
+        else:
+            assert model_name is not None and model_init_args is not None
+            self.model = MODEL_DICT[model_name](**model_init_args)
+
+        self.ckpt_file = ckpt_file
+        self.model_ckpt = model_ckpt
+
+        self.model.to(self.device)
+
+        if optimizer is not None:
+            self.optimizer = optimizer
+        elif optimizer_name is not None and optimizer_kwargs is not None:
+            params = [{
+                'params': self.model.parameters(),
+                **optimizer_kwargs
+            }]
+            self.optimizer = OPTIM_DICT[optimizer_name](params)
+        else:
+            self.optimizer = None
+
+        self.n_iter = 0
+        self.epoch_id = 0
+        self.metrics = [get_metric(metric) for metric in metrics]
+
+    @abstractmethod
+    def load_ckpt(self, ckpt_file, train=False):
+        raise NotImplementedError
+
+    @abstractmethod
+    def train_a_batch(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def train_an_epoch(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def train(self):
+        raise NotImplementedError
+
+    def save(self):
+        makedirs(osp.join(self.base_dir, 'ckpt'))
+        ckpt_file = osp.join(
+            self.base_dir, 'ckpt',
+            f'model_{self.epoch_id}.ckpt'
+        )
+
+        save_model(
+            model=self.model,
+            save_dir=ckpt_file,
+            epoch=self.epoch_id,
+            optimizer=self.optimizer,
+            loss=self.losses
+        )
+
+    def load(self, ckpt_file, train=False):
+        load_model(
+            save_dir=ckpt_file,
+            model=self.model,
+            optimizer=self.optimizer,
+            train=train
+        )
+
+    @abstractmethod
+    def predict(self, data_loader):
+        raise NotImplementedError
+
+
+class IterativeTrainer(TorchTrainer):
+    def __init__(
+        self,
+        base_dir,
+        data_loader,
+        test_loader,
+        metrics,
+        target_names,
+        loss_func,
+        logger
+    ) -> None:
+        super().__init__(
+            base_dir,
+            data_loader,
+            test_loader,
+            metrics,
+            loss_func,
+            logger
+        )
+        self.target_names = target_names
+
+    @abstractmethod
+    def train_a_batch(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def train_an_epoch(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def train(self):
+        raise NotImplementedError
+
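`TorchTrainer` only fixes the training contract: subclasses must supply `load_ckpt`, `train_a_batch`, `train_an_epoch`, `train`, and `predict`, while checkpointing via `save`/`load` comes for free. The sketch below is a minimal illustration, not code from the package: `MyTrainer`, the `(x, y)` batch layout, and the loss call are assumptions about one plausible wiring. Note that `__init__` maps `get_metric` over `metrics` unconditionally on its last line, so `metrics` must be an iterable of valid metric names.

```python
# A minimal sketch (assumptions, not package code) of a concrete trainer
# built on TorchTrainer; batch layout and loss wiring are illustrative.
import torch

from hdl.controllers.train.trainer_base import TorchTrainer


class MyTrainer(TorchTrainer):
    def load_ckpt(self, ckpt_file, train=False):
        # Delegate to the base helper wrapping hdl.models.utils.load_model.
        self.load(ckpt_file, train=train)

    def train_a_batch(self, batch):
        x, y = batch  # assumed (inputs, labels) batch layout
        self.optimizer.zero_grad()
        loss = self.loss_func(self.model(x.to(self.device)), y.to(self.device))
        loss.backward()
        self.optimizer.step()
        self.n_iter += 1
        return loss

    def train_an_epoch(self):
        for batch in self.data_loader:
            self.losses.append(self.train_a_batch(batch).item())
        self.save()  # writes ckpt/model_{epoch_id}.ckpt under base_dir

    def train(self, num_epochs):
        for self.epoch_id in range(num_epochs):
            self.train_an_epoch()

    def predict(self, data_loader):
        self.model.eval()
        with torch.no_grad():
            return [self.model(x.to(self.device)) for x, _ in data_loader]
```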
hdl/controllers/train/trainer_iterative.py ADDED
@@ -0,0 +1,389 @@
+from os import path as osp
+import typing as t
+
+# import numpy as np
+import torch
+# from torch import nn
+# from torch.optim import Adam
+from torch import nn
+# import pandas as pd
+
+from hdl.models.utils import load_model, save_model
+from hdl.models.model_dict import MODEL_DICT
+from hdl.models.optim_dict import OPTIM_DICT
+# from hdl.models.linear import MMIterLinear
+from hdl.features.fp.features_generators import FP_BITS_DICT
+from hdl.data.dataset.fp.fp_dataset import FPDataset
+from hdl.data.dataset.loaders.general import Loader
+from jupyfuncs.pbar import tnrange, tqdm
+from jupyfuncs.glob import makedirs
+from jupyfuncs.tensor import get_valid_indices
+from hdl.metric_loss.loss import mtmc_loss
+from hdl.controllers.train.trainer_base import IterativeTrainer
+
+
+class MMIterTrainer(IterativeTrainer):
+    def __init__(
+        self,
+        base_dir,
+        data_loader,
+        target_names,
+        loss_func,
+        missing_labels=[],
+        task_weights=None,
+        test_loder=None,
+        metrics=None,
+        model=None,
+        model_name=None,
+        model_init_args=None,
+        ckpt_file=None,
+        optimizer=None,
+        optimizer_name=None,
+        optimizer_kwargs=None,
+        # logger=None,
+        device=torch.device('cpu'),
+        parallel=False,
+    ):
+        super().__init__(
+            base_dir=base_dir,
+            data_loader=data_loader,
+            test_loader=test_loder,
+            metrics=metrics,
+            loss_func=loss_func,
+            target_names=target_names,
+            # logger=logger
+        )
+        assert len(missing_labels) == len(target_names)
+        self.epoch_id = 0
+        if model is not None:
+            self.model = model
+        else:
+            assert model_name is not None and model_init_args is not None
+            self.model = MODEL_DICT[model_name](**model_init_args)
+        self.model.to(device)
+        if optimizer is not None:
+            self.optimizer = optimizer
+        else:
+            assert optimizer_name is not None and optimizer_kwargs is not None
+            params = [{
+                'params': self.model.parameters(),
+                **optimizer_kwargs
+            }]
+            self.optimizer = OPTIM_DICT[optimizer_name](params)
+
+        if ckpt_file is not None:
+            self.model, self.optimizer, self.epoch_id, _ = load_model(
+                save_dir=ckpt_file,
+                model=self.model,
+                optimizer=self.optimizer,
+                train=True
+            )
+            if self.epoch_id != 0:
+                self.epoch_id += 1
+
+        self.metrics = metrics
+        self.device = device
+        self.missing_labels = missing_labels
+        if parallel:
+            self.model = nn.DataParallel(self.model)
+
+        if isinstance(loss_func, str):
+            self.loss_names = [loss_func] * len(target_names)
+        elif isinstance(loss_func, (t.List, t.Tuple)):
+            assert len(loss_func) == len(target_names)
+            self.loss_names = loss_func
+
+        if task_weights is None:
+            task_weights = [1] * len(target_names)
+        self.task_weights = task_weights
+
+    def train_a_batch(self, batch):
+        self.optimizer.zero_grad()
+        fps = [x.to(self.device) for x in batch[0]]
+        target_tensors = [
+            target_tensor.to(self.device) for target_tensor in batch[1]
+        ]
+        target_list = batch[-1]
+        target_valid_dict = {}
+        for target_name, target_labels in zip(self.target_names, target_list):
+
+            valid_indices = get_valid_indices(labels=target_labels)
+            valid_indices.to(self.device)
+
+            target_valid_dict[target_name] = valid_indices
+
+        y_preds = self.model(fps, target_tensors, teach=True)
+
+        # process with y_true
+        y_trues = []
+        y_preds_list = []
+        for target_name, target_tensor, target_labels, loss_name in zip(
+            self.target_names, target_tensors, target_list, self.loss_names
+        ):
+            valid_indices = target_valid_dict[target_name]
+            if loss_name in ['ce']:
+                y_true = target_labels[valid_indices].long().to(self.device)
+            elif loss_name in ['mse', 'bpmll']:
+                y_true = target_tensor[valid_indices].to(self.device)
+            y_pred = y_preds[target_name][valid_indices].to(self.device)
+            y_preds_list.append(y_pred)
+            y_trues.append(y_true)
+
+        # print(y_preds, y_trues)
+        loss, loss_list = mtmc_loss(
+            y_preds=y_preds_list,
+            y_trues=y_trues,
+            loss_names=self.loss_names,
+            individual=True,
+            task_weights=self.task_weights,
+            device=self.device
+        )
+        with open(osp.join(self.base_dir, 'loss.log'), 'a') as f:
+            f.write(str(loss.item()))
+            f.write('\t')
+            for i_loss in loss_list:
+                f.write(str(i_loss.item()))
+                f.write('\t')
+            f.write('\n')
+            f.flush()
+
+        loss.backward()
+        self.optimizer.step()
+
+        return loss
+
+    def train_an_epoch(self, epoch_id):
+        for batch in tqdm(self.data_loader):
+            loss = self.train_a_batch(
+                batch=batch
+            )
+        makedirs(osp.join(self.base_dir, 'ckpt'))
+        self.ckpt_file = osp.join(
+            self.base_dir, 'ckpt',
+            f'model.{epoch_id}.ckpt'
+        )
+        save_model(
+            model=self.model,
+            save_dir=self.ckpt_file,
+            epoch=epoch_id,
+            optimizer=self.optimizer,
+            loss=loss
+        )
+
+    def train(self, num_epochs):
+        for self.epoch_id in tnrange(
+            self.epoch_id,
+            self.epoch_id + num_epochs
+        ):
+            self.train_an_epoch(
+                epoch_id=self.epoch_id
+            )
+
+
+class MMIterTrainerBack(IterativeTrainer):
+    def __init__(
+        self,
+        base_dir,
+        model,
+        optimizer,
+        data_loader,
+        target_cols,
+        num_epochs,
+        loss_func,
+        ckpt_file,
+        device,
+        individual,
+        logger=None
+    ):
+        super().__init__(
+            base_dir=base_dir,
+            data_loader=data_loader,
+            loss_func=loss_func,
+            logger=logger
+        )
+        self.model = model
+        self.optimizer = optimizer
+        self.target_cols = target_cols
+        self.num_epochs = num_epochs
+        self.ckpt_file = ckpt_file
+        self.device = device
+        self.individual = individual
+
+    def run(self):
+        for i, task in tqdm(enumerate(self.target_cols)):
+            if self.ckpt_file is not None:
+                self.model, self.optimizer, _, _ = load_model(
+                    self.ckpt_file,
+                    model=self.model,
+                    optimizer=self.optimizer,
+                    train=True,
+                )
+
+            self.model.freeze_classifier[i] = False
+
+            for epoch_id in tnrange(self.num_epochs):
+                self.train_an_epoch(
+                    target_ind=i,
+                    target_name=task,
+                    epoch_id=epoch_id
+                )
+
+    def train_a_batch(
+        self,
+        batch,
+        target_ind,
+        target_name,
+        # epoch_id
+    ):
+        self.optimizer.zero_grad()
+
+        y = (batch[-1][target_ind]).to(self.device)
+        X = [x.to(self.device)[y >= 0].float() for x in batch[0]]
+        y = y[y >= 0]
+
+        y_preds = self.model(X, teach=False)[target_name]
+
+        loss_name = self.loss_func[target_ind] \
+            if isinstance(self.loss_func, list) else self.loss_func
+
+        loss = mtmc_loss(
+            [y_preds],
+            [y],
+            loss_names=loss_name,
+            individual=self.individual,
+            device=self.device
+        )
+
+        if not self.individual:
+            final_loss = loss
+            individual_losses = []
+        else:
+            final_loss = loss[0]
+            individual_losses = loss[1]
+
+        final_loss.backward()
+        self.optimizer.step()
+
+        with open(
+            osp.join(self.base_dir, target_name + '_loss.log'),
+            'a'
+        ) as f:
+            f.write(str(final_loss.item()))
+            f.write('\t')
+            for individual_loss in individual_losses:
+                f.write(str(individual_loss))
+                f.write('\t')
+            f.write('\n')
+        return loss
+
+    def train_an_epoch(
+        self,
+        target_ind,
+        target_name,
+        epoch_id,
+    ):
+
+        for batch in tqdm(self.data_loader):
+            loss = self.train_a_batch(
+                batch=batch,
+                target_ind=target_ind,
+                target_name=target_name
+            )
+
+        makedirs(osp.join(self.base_dir, 'ckpt'))
+        self.ckpt_file = osp.join(
+            self.base_dir, 'ckpt',
+            f'model.{target_name}_{epoch_id}.ckpt'
+        )
+        save_model(
+            model=self.model,
+            save_dir=self.ckpt_file,
+            epoch=epoch_id,
+            optimizer=self.optimizer,
+            loss=loss
+        )
+
+
+def train(
+    base_dir: str,
+    csv_file: str,
+    splitter: str,
+    model_name: str,
+    # model_init_args: t.Dict,
+    ckpt_file: str = None,
+    smiles_cols: t.List = [],
+    fp_type: str = 'morgan_count',
+    num_epochs: int = 20,
+    target_cols: t.List = [],
+    nums_classes: t.List = [],
+    missing_labels: t.List = [],
+    target_transform: t.List = [],
+    optimizer_name: str = 'adam',
+    loss_func: str = 'ce',
+    batch_size: int = 128,
+    hidden_size: int = 128,
+    num_hidden_layers: int = 10,
+    num_workers: int = 12,
+    device_id: int = 0,
+    **kwargs
+):
+    base_dir = osp.abspath(base_dir)
+    makedirs(base_dir)
+
+    device = torch.device(f'cuda:{device_id}') \
+        if torch.cuda.is_available() \
+        else torch.device('cpu')
+    if kwargs.get('cpu', False):
+        device = torch.device('cpu')
+
+    converters = kwargs.get('converters', {})
+    dataset = FPDataset(
+        csv_file=csv_file,
+        splitter=splitter,
+        smiles_cols=smiles_cols,
+        target_cols=target_cols,
+        num_classes=nums_classes,
+        missing_labels=missing_labels,
+        target_transform=target_transform,
+        fp_type=fp_type,
+        converters=converters
+    )
+    data_loader = Loader(
+        dataset=dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=num_workers
+    )
+    model_init_args = {}
+    model_init_args['nums_classes'] = nums_classes
+    model_init_args['target_names'] = target_cols
+    model_init_args['num_fp_bits'] = FP_BITS_DICT[fp_type]
+    model_init_args['hidden_size'] = hidden_size
+    model_init_args['num_hidden_layers'] = num_hidden_layers
+    model_init_args['dim'] = kwargs.get('dim', -1)
+    model_init_args['hard_select'] = kwargs.get('hard_select', False)
+    model_init_args['iterative'] = kwargs.get('iterative', True)
+    model_init_args['num_in_feats'] = kwargs.get('num_in_feats', 1024)
+
+    trainer = MMIterTrainer(
+        base_dir=base_dir,
+        data_loader=data_loader,
+        target_names=target_cols,
+        loss_func=loss_func,
+        missing_labels=missing_labels,
+        task_weights=kwargs.get('task_weights', None),
+        test_loder=None,
+        metrics=None,
+        model_name=model_name,
+        model_init_args=model_init_args,
+        ckpt_file=ckpt_file,
+        optimizer_name=optimizer_name,
+        optimizer_kwargs={
+            'lr': kwargs.get('lr', 0.01),
+            'weight_decay': kwargs.get('weight_decay', 0)
+        },
+        logger=None,
+        device=device,
+        parallel=kwargs.get('parallel', False)
+    )
+    trainer.train(num_epochs=num_epochs)
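The module-level `train` function is the glue here: it builds an `FPDataset` from a CSV file, wraps it in a `Loader`, assembles `model_init_args`, and hands everything to `MMIterTrainer`. A hedged invocation sketch follows; the file name, column names, and the `model_name` value are invented placeholders (valid model keys live in `MODEL_DICT`). Note also that, as published, `train` passes `logger=None` to `MMIterTrainer`, whose `__init__` has that parameter commented out, so this call would need that mismatch resolved to run.

```python
# Hypothetical invocation of the module-level train() entry point above.
# 'data.csv', the column names, and 'mmiter_linear' are placeholders.
from hdl.controllers.train.trainer_iterative import train

train(
    base_dir='runs/demo',
    csv_file='data.csv',            # placeholder path
    splitter=',',
    model_name='mmiter_linear',     # placeholder; must be a MODEL_DICT key
    smiles_cols=['SMILES'],
    target_cols=['activity', 'toxicity'],
    nums_classes=[2, 2],
    missing_labels=[-1, -1],        # one missing-label marker per target
    target_transform=['tensor', 'tensor'],
    loss_func='ce',                 # or one loss name per target column
    num_epochs=20,
    device_id=0,
)
```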
hdl/data/__init__.py ADDED
File without changes
hdl/data/dataset/base_dataset.py ADDED
@@ -0,0 +1,98 @@
+from os import path as osp
+import typing as t
+
+import torch.utils.data as tud
+import pandas as pd
+
+from jupyfuncs.dataframe import rm_index
+from jupyfuncs.tensor import (
+    label_to_onehot,
+    label_to_tensor
+)
+
+
+def percent(x, *args, **kwargs):
+    return x / 100
+
+
+label_trans_dict = {
+    'onehot': label_to_onehot,
+    'tensor': label_to_tensor,
+    'percent': percent
+}
+
+
+class CSVDataset(tud.Dataset):
+    def __init__(
+        self,
+        csv_file: str,
+        splitter: str = ',',
+        smiles_col: str = 'SMILES',
+        target_cols: t.List[str] = [],
+        num_classes: t.List[int] = [],
+        target_transform: t.Union[str, t.List[str]] = None,
+        **kwargs
+    ) -> None:
+        super().__init__()
+        self.csv = osp.abspath(csv_file)
+        df = pd.read_csv(
+            self.csv,
+            sep=splitter,
+            **kwargs
+        )
+        self.df = rm_index(df)
+        self.smiles_col = smiles_col
+        self.target_cols = target_cols
+        self.num_classes = num_classes
+        if target_transform is not None:
+            if not num_classes:
+                self.num_classes = [1 for _ in range(len(target_cols))]
+            else:
+                assert len(self.num_classes) == len(target_cols)
+            if isinstance(target_transform, str):
+                self.target_transform = [label_trans_dict[target_transform]] * \
+                    len(self.num_classes)
+            elif isinstance(target_transform, t.Iterable):
+                self.target_transform = [
+                    label_trans_dict[target_trans]
+                    for target_trans in target_transform
+                ]
+        else:
+            self.target_transform = None
+
+    def __getitem__(self, index):
+        raise NotImplementedError
+
+    def __len__(self):
+        return len(self.df)
+
+
+class CSVRDataset(tud.Dataset):
+    def __init__(
+        self,
+        csv_file: str,
+        splitter: str,
+        smiles_col: str,
+        target_col: str = None,
+        missing_label: str = None,
+        target_transform: t.Union[str, t.List[str]] = None,
+        **kwargs
+    ) -> None:
+        self.csv_file = csv_file
+        df = pd.read_csv(
+            self.csv_file,
+            sep=splitter,
+            **kwargs
+        )
+        self.df = rm_index(df)
+        self.smiles_col = smiles_col
+        self.target_col = target_col
+        self.miss_label = missing_label
+        if target_transform is not None:
+            self.target_transform = label_trans_dict[target_transform]
+
+    def __getitem__(self, index):
+        raise NotImplementedError
+
+    def __len__(self):
+        return len(self.df)
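`CSVDataset` and `CSVRDataset` are abstract bases: they load the CSV, resolve label transforms through `label_trans_dict`, and leave `__getitem__` to subclasses (the package's concrete variant lives in `hdl/data/dataset/fp/fp_dataset.py`). Below is a minimal illustrative subclass, not package code; the class name, the placeholder CSV path, and the assumed `(label, num_classes)` transform signature are guesses (`percent` at least tolerates the extra argument via `*args`).

```python
# Illustrative sketch (assumptions, not package code): a concrete dataset
# on top of CSVDataset that returns the SMILES string plus transformed labels.
from hdl.data.dataset.base_dataset import CSVDataset


class SmilesTargetDataset(CSVDataset):
    def __getitem__(self, index):
        row = self.df.iloc[index]
        smiles = row[self.smiles_col]
        targets = [row[col] for col in self.target_cols]
        if self.target_transform is not None:
            targets = [
                trans(label, num_cls)  # assumed (label, num_classes) signature
                for trans, label, num_cls in zip(
                    self.target_transform, targets, self.num_classes
                )
            ]
        return smiles, targets


# Usage sketch ('data.csv' is a placeholder):
# ds = SmilesTargetDataset('data.csv', target_cols=['y'], target_transform='tensor')
```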