hjxdl-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. hdl/__init__.py +0 -0
  2. hdl/_version.py +16 -0
  3. hdl/args/__init__.py +0 -0
  4. hdl/args/loss_args.py +5 -0
  5. hdl/controllers/__init__.py +0 -0
  6. hdl/controllers/al/__init__.py +0 -0
  7. hdl/controllers/al/al.py +0 -0
  8. hdl/controllers/al/dispatcher.py +0 -0
  9. hdl/controllers/al/feedback.py +0 -0
  10. hdl/controllers/explain/__init__.py +0 -0
  11. hdl/controllers/explain/shapley.py +293 -0
  12. hdl/controllers/explain/subgraphx.py +865 -0
  13. hdl/controllers/train/__init__.py +0 -0
  14. hdl/controllers/train/rxn_train.py +219 -0
  15. hdl/controllers/train/train.py +50 -0
  16. hdl/controllers/train/train_ginet.py +316 -0
  17. hdl/controllers/train/trainer_base.py +155 -0
  18. hdl/controllers/train/trainer_iterative.py +389 -0
  19. hdl/data/__init__.py +0 -0
  20. hdl/data/dataset/__init__.py +0 -0
  21. hdl/data/dataset/base_dataset.py +98 -0
  22. hdl/data/dataset/fp/__init__.py +0 -0
  23. hdl/data/dataset/fp/fp_dataset.py +122 -0
  24. hdl/data/dataset/graph/__init__.py +0 -0
  25. hdl/data/dataset/graph/chiral.py +62 -0
  26. hdl/data/dataset/graph/gin.py +255 -0
  27. hdl/data/dataset/graph/molnet.py +362 -0
  28. hdl/data/dataset/loaders/__init__.py +0 -0
  29. hdl/data/dataset/loaders/chiral_graph.py +71 -0
  30. hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
  31. hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
  32. hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
  33. hdl/data/dataset/loaders/general.py +23 -0
  34. hdl/data/dataset/loaders/spliter.py +86 -0
  35. hdl/data/dataset/samplers/__init__.py +0 -0
  36. hdl/data/dataset/samplers/chiral.py +19 -0
  37. hdl/data/dataset/seq/__init__.py +0 -0
  38. hdl/data/dataset/seq/rxn_dataset.py +61 -0
  39. hdl/data/dataset/utils.py +31 -0
  40. hdl/data/to_mols.py +0 -0
  41. hdl/features/__init__.py +0 -0
  42. hdl/features/fp/__init__.py +0 -0
  43. hdl/features/fp/features_generators.py +235 -0
  44. hdl/features/graph/__init__.py +0 -0
  45. hdl/features/graph/featurization.py +297 -0
  46. hdl/features/utils/__init__.py +0 -0
  47. hdl/features/utils/utils.py +111 -0
  48. hdl/layers/__init__.py +0 -0
  49. hdl/layers/general/__init__.py +0 -0
  50. hdl/layers/general/gp.py +14 -0
  51. hdl/layers/general/linear.py +641 -0
  52. hdl/layers/graph/__init__.py +0 -0
  53. hdl/layers/graph/chiral_graph.py +230 -0
  54. hdl/layers/graph/gcn.py +16 -0
  55. hdl/layers/graph/gin.py +45 -0
  56. hdl/layers/graph/tetra.py +158 -0
  57. hdl/layers/graph/transformer.py +188 -0
  58. hdl/layers/sequential/__init__.py +0 -0
  59. hdl/metric_loss/__init__.py +0 -0
  60. hdl/metric_loss/loss.py +79 -0
  61. hdl/metric_loss/metric.py +178 -0
  62. hdl/metric_loss/multi_label.py +42 -0
  63. hdl/metric_loss/nt_xent.py +65 -0
  64. hdl/models/__init__.py +0 -0
  65. hdl/models/chiral_gnn.py +176 -0
  66. hdl/models/fast_transformer.py +234 -0
  67. hdl/models/ginet.py +189 -0
  68. hdl/models/linear.py +137 -0
  69. hdl/models/model_dict.py +18 -0
  70. hdl/models/norm_flows.py +33 -0
  71. hdl/models/optim_dict.py +16 -0
  72. hdl/models/rxn.py +63 -0
  73. hdl/models/utils.py +83 -0
  74. hdl/ops/__init__.py +0 -0
  75. hdl/ops/utils.py +42 -0
  76. hdl/optims/__init__.py +0 -0
  77. hdl/optims/nadam.py +86 -0
  78. hdl/utils/__init__.py +0 -0
  79. hdl/utils/chemical_tools/__init__.py +2 -0
  80. hdl/utils/chemical_tools/query_info.py +149 -0
  81. hdl/utils/chemical_tools/sdf.py +20 -0
  82. hdl/utils/database_tools/__init__.py +0 -0
  83. hdl/utils/database_tools/connect.py +28 -0
  84. hdl/utils/general/__init__.py +0 -0
  85. hdl/utils/general/glob.py +21 -0
  86. hdl/utils/schedulers/__init__.py +0 -0
  87. hdl/utils/schedulers/norm_lr.py +108 -0
  88. hjxdl-0.0.1.dist-info/METADATA +19 -0
  89. hjxdl-0.0.1.dist-info/RECORD +91 -0
  90. hjxdl-0.0.1.dist-info/WHEEL +5 -0
  91. hjxdl-0.0.1.dist-info/top_level.txt +1 -0
hdl/controllers/train/trainer_base.py ADDED
@@ -0,0 +1,155 @@
+ from abc import abstractmethod, ABC
+ from os import path as osp
+
+ import torch
+ from torch.utils.tensorboard import SummaryWriter
+
+ from jupyfuncs.glob import makedirs
+
+ from hdl.models.optim_dict import OPTIM_DICT
+ from hdl.models.model_dict import MODEL_DICT
+ from hdl.models.utils import save_model, load_model
+ from hdl.metric_loss.loss import get_lossfunc
+ from hdl.metric_loss.metric import get_metric
+
+
+ class TorchTrainer(ABC):
+     def __init__(
+         self,
+         base_dir,
+         data_loader,
+         test_loader,
+         metrics,
+         loss_func,
+         model=None,
+         model_name=None,
+         model_init_args=None,
+         ckpt_file=None,
+         model_ckpt=None,
+         optimizer=None,
+         optimizer_name=None,
+         optimizer_kwargs=None,
+         device=torch.device('cpu'),
+     ) -> None:
+         super().__init__()
+         self.base_dir = base_dir
+         self.data_loader = data_loader
+         self.test_loader = test_loader
+
+         if metrics is not None:
+             self.metric_names = metrics
+             self.metrics = [get_metric(metric) for metric in metrics]
+         if loss_func is not None:
+             self.loss_name = loss_func
+             self.loss_func = get_lossfunc(loss_func)
+         if isinstance(device, str):
+             self.device = torch.device(device)
+         else:
+             self.device = device
+         self.logger = SummaryWriter(log_dir=self.base_dir)
+
+         self.losses = []
+
+         if model is not None:
+             self.model = model
+         else:
+             assert model_name is not None and model_init_args is not None
+             self.model = MODEL_DICT[model_name](**model_init_args)
+
+         self.ckpt_file = ckpt_file
+         self.model_ckpt = model_ckpt
+
+         self.model.to(self.device)
+
+         if optimizer is not None:
+             self.optimizer = optimizer
+         elif optimizer_name is not None and optimizer_kwargs is not None:
+             params = [{
+                 'params': self.model.parameters(),
+                 **optimizer_kwargs
+             }]
+             self.optimizer = OPTIM_DICT[optimizer_name](params)
+         else:
+             self.optimizer = None
+
+         self.n_iter = 0
+         self.epoch_id = 0
+
+     @abstractmethod
+     def load_ckpt(self, ckpt_file, train=False):
+         raise NotImplementedError
+
+     @abstractmethod
+     def train_a_batch(self):
+         raise NotImplementedError
+
+     @abstractmethod
+     def train_an_epoch(self):
+         raise NotImplementedError
+
+     @abstractmethod
+     def train(self):
+         raise NotImplementedError
+
+     def save(self):
+         makedirs(osp.join(self.base_dir, 'ckpt'))
+         ckpt_file = osp.join(
+             self.base_dir, 'ckpt',
+             f'model_{self.epoch_id}.ckpt'
+         )
+
+         save_model(
+             model=self.model,
+             save_dir=ckpt_file,
+             epoch=self.epoch_id,
+             optimizer=self.optimizer,
+             loss=self.losses
+         )
+
+     def load(self, ckpt_file, train=False):
+         load_model(
+             save_dir=ckpt_file,
+             model=self.model,
+             optimizer=self.optimizer,
+             train=train
+         )
+
+     @abstractmethod
+     def predict(self, data_loader):
+         raise NotImplementedError
+
+
+ class IterativeTrainer(TorchTrainer):
+     def __init__(
+         self,
+         base_dir,
+         data_loader,
+         test_loader,
+         metrics,
+         target_names,
+         loss_func,
+         logger=None
+     ) -> None:
+         super().__init__(
+             base_dir,
+             data_loader,
+             test_loader,
+             metrics,
+             loss_func
+         )
+         # TorchTrainer creates its own SummaryWriter; an explicit logger,
+         # if given, overrides it.
+         if logger is not None:
+             self.logger = logger
+         self.target_names = target_names
+
+     @abstractmethod
+     def train_a_batch(self):
+         raise NotImplementedError
+
+     @abstractmethod
+     def train_an_epoch(self):
+         raise NotImplementedError
+
+     @abstractmethod
+     def train(self):
+         raise NotImplementedError
+
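trainer_base.py only fixes the training skeleton; the abstract hooks are left to subclasses. Below is a minimal sketch of a conforming subclass, assuming a plain (inputs, labels) batch layout and using only the attributes the base class sets up — the class name and batch layout are illustrative, not part of the package:

import torch

from hdl.controllers.train.trainer_base import TorchTrainer


class SupervisedTrainer(TorchTrainer):
    # Hypothetical subclass: the (inputs, labels) batch layout is assumed.
    def load_ckpt(self, ckpt_file, train=False):
        self.load(ckpt_file, train=train)

    def train_a_batch(self, batch):
        inputs, labels = (x.to(self.device) for x in batch)
        self.optimizer.zero_grad()
        loss = self.loss_func(self.model(inputs), labels)
        loss.backward()
        self.optimizer.step()
        self.n_iter += 1
        return loss

    def train_an_epoch(self):
        for batch in self.data_loader:
            self.losses.append(self.train_a_batch(batch).item())

    def train(self, num_epochs):
        for _ in range(num_epochs):
            self.train_an_epoch()
            self.save()  # checkpoints to <base_dir>/ckpt/model_<epoch>.ckpt
            self.epoch_id += 1

    def predict(self, data_loader):
        self.model.eval()
        with torch.no_grad():
            return [self.model(x.to(self.device)) for x, _ in data_loader]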
hdl/controllers/train/trainer_iterative.py ADDED
@@ -0,0 +1,389 @@
+ from os import path as osp
+ import typing as t
+
+ import torch
+ from torch import nn
+
+ from hdl.models.utils import load_model, save_model
+ from hdl.models.model_dict import MODEL_DICT
+ from hdl.models.optim_dict import OPTIM_DICT
+ from hdl.features.fp.features_generators import FP_BITS_DICT
+ from hdl.data.dataset.fp.fp_dataset import FPDataset
+ from hdl.data.dataset.loaders.general import Loader
+ from jupyfuncs.pbar import tnrange, tqdm
+ from jupyfuncs.glob import makedirs
+ from jupyfuncs.tensor import get_valid_indices
+ from hdl.metric_loss.loss import mtmc_loss
+ from hdl.controllers.train.trainer_base import IterativeTrainer
+
+
+ class MMIterTrainer(IterativeTrainer):
+     def __init__(
+         self,
+         base_dir,
+         data_loader,
+         target_names,
+         loss_func,
+         missing_labels=[],
+         task_weights=None,
+         test_loader=None,
+         metrics=None,
+         model=None,
+         model_name=None,
+         model_init_args=None,
+         ckpt_file=None,
+         optimizer=None,
+         optimizer_name=None,
+         optimizer_kwargs=None,
+         logger=None,
+         device=torch.device('cpu'),
+         parallel=False,
+     ):
+         super().__init__(
+             base_dir=base_dir,
+             data_loader=data_loader,
+             test_loader=test_loader,
+             metrics=metrics,
+             loss_func=loss_func,
+             target_names=target_names,
+             logger=logger
+         )
+         assert len(missing_labels) == len(target_names)
+         self.epoch_id = 0
+         if model is not None:
+             self.model = model
+         else:
+             assert model_name is not None and model_init_args is not None
+             self.model = MODEL_DICT[model_name](**model_init_args)
+         self.model.to(device)
+         if optimizer is not None:
+             self.optimizer = optimizer
+         else:
+             assert optimizer_name is not None and optimizer_kwargs is not None
+             params = [{
+                 'params': self.model.parameters(),
+                 **optimizer_kwargs
+             }]
+             self.optimizer = OPTIM_DICT[optimizer_name](params)
+
+         if ckpt_file is not None:
+             self.model, self.optimizer, self.epoch_id, _ = load_model(
+                 save_dir=ckpt_file,
+                 model=self.model,
+                 optimizer=self.optimizer,
+                 train=True
+             )
+             if self.epoch_id != 0:
+                 self.epoch_id += 1
+
+         self.metrics = metrics
+         self.device = device
+         self.missing_labels = missing_labels
+         if parallel:
+             self.model = nn.DataParallel(self.model)
+
+         if isinstance(loss_func, str):
+             self.loss_names = [loss_func] * len(target_names)
+         elif isinstance(loss_func, (list, tuple)):
+             assert len(loss_func) == len(target_names)
+             self.loss_names = loss_func
+
+         if task_weights is None:
+             task_weights = [1] * len(target_names)
+         self.task_weights = task_weights
+
+     def train_a_batch(self, batch):
+         self.optimizer.zero_grad()
+         fps = [x.to(self.device) for x in batch[0]]
+         target_tensors = [
+             target_tensor.to(self.device) for target_tensor in batch[1]
+         ]
+         target_list = batch[-1]
+         target_valid_dict = {}
+         for target_name, target_labels in zip(self.target_names, target_list):
+             # keep the device-moved copy: Tensor.to returns a new tensor
+             valid_indices = get_valid_indices(labels=target_labels).to(self.device)
+             target_valid_dict[target_name] = valid_indices
+
+         y_preds = self.model(fps, target_tensors, teach=True)
+
+         # build y_true/y_pred pairs restricted to the valid (non-missing) rows
+         y_trues = []
+         y_preds_list = []
+         for target_name, target_tensor, target_labels, loss_name in zip(
+             self.target_names, target_tensors, target_list, self.loss_names
+         ):
+             valid_indices = target_valid_dict[target_name]
+             if loss_name in ['ce']:
+                 y_true = target_labels[valid_indices].long().to(self.device)
+             elif loss_name in ['mse', 'bpmll']:
+                 y_true = target_tensor[valid_indices].to(self.device)
+             y_pred = y_preds[target_name][valid_indices].to(self.device)
+             y_preds_list.append(y_pred)
+             y_trues.append(y_true)
+
+         loss, loss_list = mtmc_loss(
+             y_preds=y_preds_list,
+             y_trues=y_trues,
+             loss_names=self.loss_names,
+             individual=True,
+             task_weights=self.task_weights,
+             device=self.device
+         )
+         # one tab-separated line per batch: total loss, then per-task losses
+         with open(osp.join(self.base_dir, 'loss.log'), 'a') as f:
+             f.write(str(loss.item()))
+             f.write('\t')
+             for i_loss in loss_list:
+                 f.write(str(i_loss.item()))
+                 f.write('\t')
+             f.write('\n')
+             f.flush()
+
+         loss.backward()
+         self.optimizer.step()
+
+         return loss
+
+     def train_an_epoch(self, epoch_id):
+         for batch in tqdm(self.data_loader):
+             loss = self.train_a_batch(batch=batch)
+         makedirs(osp.join(self.base_dir, 'ckpt'))
+         self.ckpt_file = osp.join(
+             self.base_dir, 'ckpt',
+             f'model.{epoch_id}.ckpt'
+         )
+         save_model(
+             model=self.model,
+             save_dir=self.ckpt_file,
+             epoch=epoch_id,
+             optimizer=self.optimizer,
+             loss=loss
+         )
+
+     def train(self, num_epochs):
+         for self.epoch_id in tnrange(
+             self.epoch_id,
+             self.epoch_id + num_epochs
+         ):
+             self.train_an_epoch(epoch_id=self.epoch_id)
+
+
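Each call to MMIterTrainer.train_a_batch appends one tab-separated line to <base_dir>/loss.log: the weighted total loss, then each task's individual loss in target_names order. A small sketch for reading that log back (the helper name and dict layout are ours, not part of the package):

def read_loss_log(path, target_names):
    """Parse the tab-separated loss.log written by MMIterTrainer.

    Each row becomes a dict: {'total': ..., '<task name>': ...}.
    """
    rows = []
    with open(path) as f:
        for line in f:
            values = [float(v) for v in line.strip().split('\t')]
            row = {'total': values[0]}
            row.update(zip(target_names, values[1:]))
            rows.append(row)
    return rows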
+ class MMIterTrainerBack(IterativeTrainer):
+     def __init__(
+         self,
+         base_dir,
+         model,
+         optimizer,
+         data_loader,
+         target_cols,
+         num_epochs,
+         loss_func,
+         ckpt_file,
+         device,
+         individual,
+         logger=None
+     ):
+         # IterativeTrainer also expects test_loader, metrics and
+         # target_names; this legacy trainer carries no test loader or
+         # metrics, and its targets are the target_cols.
+         super().__init__(
+             base_dir=base_dir,
+             data_loader=data_loader,
+             test_loader=None,
+             metrics=None,
+             target_names=target_cols,
+             loss_func=loss_func,
+             logger=logger
+         )
+         self.model = model
+         self.optimizer = optimizer
+         self.target_cols = target_cols
+         self.num_epochs = num_epochs
+         self.ckpt_file = ckpt_file
+         self.device = device
+         self.individual = individual
+
+     def run(self):
+         for i, task in tqdm(enumerate(self.target_cols)):
+             if self.ckpt_file is not None:
+                 self.model, self.optimizer, _, _ = load_model(
+                     self.ckpt_file,
+                     model=self.model,
+                     optimizer=self.optimizer,
+                     train=True,
+                 )
+
+             # unfreeze only the classifier head for the current task
+             self.model.freeze_classifier[i] = False
+
+             for epoch_id in tnrange(self.num_epochs):
+                 self.train_an_epoch(
+                     target_ind=i,
+                     target_name=task,
+                     epoch_id=epoch_id
+                 )
+
+     def train_a_batch(
+         self,
+         batch,
+         target_ind,
+         target_name,
+     ):
+         self.optimizer.zero_grad()
+
+         # drop samples whose label for this task is missing (encoded < 0)
+         y = (batch[-1][target_ind]).to(self.device)
+         X = [x.to(self.device)[y >= 0].float() for x in batch[0]]
+         y = y[y >= 0]
+
+         y_preds = self.model(X, teach=False)[target_name]
+
+         loss_name = self.loss_func[target_ind] \
+             if isinstance(self.loss_func, list) else self.loss_func
+
+         loss = mtmc_loss(
+             [y_preds],
+             [y],
+             loss_names=loss_name,
+             individual=self.individual,
+             device=self.device
+         )
+
+         if not self.individual:
+             final_loss = loss
+             individual_losses = []
+         else:
+             final_loss = loss[0]
+             individual_losses = loss[1]
+
+         final_loss.backward()
+         self.optimizer.step()
+
+         with open(
+             osp.join(self.base_dir, target_name + '_loss.log'),
+             'a'
+         ) as f:
+             f.write(str(final_loss.item()))
+             f.write('\t')
+             for individual_loss in individual_losses:
+                 f.write(str(individual_loss))
+                 f.write('\t')
+             f.write('\n')
+         return loss
+
+     def train_an_epoch(
+         self,
+         target_ind,
+         target_name,
+         epoch_id,
+     ):
+         for batch in tqdm(self.data_loader):
+             loss = self.train_a_batch(
+                 batch=batch,
+                 target_ind=target_ind,
+                 target_name=target_name
+             )
+
+         makedirs(osp.join(self.base_dir, 'ckpt'))
+         self.ckpt_file = osp.join(
+             self.base_dir, 'ckpt',
+             f'model.{target_name}_{epoch_id}.ckpt'
+         )
+         save_model(
+             model=self.model,
+             save_dir=self.ckpt_file,
+             epoch=epoch_id,
+             optimizer=self.optimizer,
+             loss=loss
+         )
+
+
+ def train(
+     base_dir: str,
+     csv_file: str,
+     splitter: str,
+     model_name: str,
+     ckpt_file: str = None,
+     smiles_cols: t.List = [],
+     fp_type: str = 'morgan_count',
+     num_epochs: int = 20,
+     target_cols: t.List = [],
+     nums_classes: t.List = [],
+     missing_labels: t.List = [],
+     target_transform: t.List = [],
+     optimizer_name: str = 'adam',
+     loss_func: str = 'ce',
+     batch_size: int = 128,
+     hidden_size: int = 128,
+     num_hidden_layers: int = 10,
+     num_workers: int = 12,
+     device_id: int = 0,
+     **kwargs
+ ):
+     base_dir = osp.abspath(base_dir)
+     makedirs(base_dir)
+
+     device = torch.device(f'cuda:{device_id}') \
+         if torch.cuda.is_available() \
+         else torch.device('cpu')
+     if kwargs.get('cpu', False):
+         device = torch.device('cpu')
+
+     converters = kwargs.get('converters', {})
+     dataset = FPDataset(
+         csv_file=csv_file,
+         splitter=splitter,
+         smiles_cols=smiles_cols,
+         target_cols=target_cols,
+         num_classes=nums_classes,
+         missing_labels=missing_labels,
+         target_transform=target_transform,
+         fp_type=fp_type,
+         converters=converters
+     )
+     data_loader = Loader(
+         dataset=dataset,
+         batch_size=batch_size,
+         shuffle=True,
+         num_workers=num_workers
+     )
+     model_init_args = {
+         'nums_classes': nums_classes,
+         'target_names': target_cols,
+         'num_fp_bits': FP_BITS_DICT[fp_type],
+         'hidden_size': hidden_size,
+         'num_hidden_layers': num_hidden_layers,
+         'dim': kwargs.get('dim', -1),
+         'hard_select': kwargs.get('hard_select', False),
+         'iterative': kwargs.get('iterative', True),
+         'num_in_feats': kwargs.get('num_in_feats', 1024),
+     }
+
+     trainer = MMIterTrainer(
+         base_dir=base_dir,
+         data_loader=data_loader,
+         target_names=target_cols,
+         loss_func=loss_func,
+         missing_labels=missing_labels,
+         task_weights=kwargs.get('task_weights', None),
+         test_loader=None,
+         metrics=None,
+         model_name=model_name,
+         model_init_args=model_init_args,
+         ckpt_file=ckpt_file,
+         optimizer_name=optimizer_name,
+         optimizer_kwargs={
+             'lr': kwargs.get('lr', 0.01),
+             'weight_decay': kwargs.get('weight_decay', 0)
+         },
+         logger=None,
+         device=device,
+         parallel=kwargs.get('parallel', False)
+     )
+     trainer.train(num_epochs=num_epochs)
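
For orientation, a hedged invocation sketch of this train entry point; the paths, column names, and model key below are placeholders, and the model key must actually exist in MODEL_DICT:

train(
    base_dir='runs/demo',                    # outputs and ckpt/ live here
    csv_file='data/train.csv',               # placeholder path
    splitter=',',
    model_name='mm_iter_linear',             # hypothetical MODEL_DICT key
    smiles_cols=['SMILES'],
    target_cols=['activity', 'solubility'],
    nums_classes=[2, 1],
    missing_labels=[-1, -1],                 # one entry per target
    target_transform=['onehot', 'percent'],
    loss_func=['ce', 'mse'],                 # one loss name per target
    num_epochs=5,
    lr=0.001,
)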
hdl/data/__init__.py ADDED
File without changes
hdl/data/dataset/base_dataset.py ADDED
@@ -0,0 +1,98 @@
+ from os import path as osp
+ import typing as t
+
+ import torch.utils.data as tud
+ import pandas as pd
+
+ from jupyfuncs.dataframe import rm_index
+ from jupyfuncs.tensor import (
+     label_to_onehot,
+     label_to_tensor
+ )
+
+
+ def percent(x, *args, **kwargs):
+     return x / 100
+
+
+ label_trans_dict = {
+     'onehot': label_to_onehot,
+     'tensor': label_to_tensor,
+     'percent': percent
+ }
+
+
+ class CSVDataset(tud.Dataset):
+     def __init__(
+         self,
+         csv_file: str,
+         splitter: str = ',',
+         smiles_col: str = 'SMILES',
+         target_cols: t.List[str] = [],
+         num_classes: t.List[int] = [],
+         target_transform: t.Union[str, t.List[str]] = None,
+         **kwargs
+     ) -> None:
+         super().__init__()
+         self.csv = osp.abspath(csv_file)
+         df = pd.read_csv(
+             self.csv,
+             sep=splitter,
+             **kwargs
+         )
+         self.df = rm_index(df)
+         self.smiles_col = smiles_col
+         self.target_cols = target_cols
+         self.num_classes = num_classes
+         if target_transform is not None:
+             if not num_classes:
+                 self.num_classes = [1 for _ in range(len(target_cols))]
+             else:
+                 assert len(self.num_classes) == len(target_cols)
+             if isinstance(target_transform, str):
+                 self.target_transform = [label_trans_dict[target_transform]] * \
+                     len(self.num_classes)
+             elif isinstance(target_transform, t.Iterable):
+                 self.target_transform = [
+                     label_trans_dict[target_trans]
+                     for target_trans in target_transform
+                 ]
+         else:
+             self.target_transform = None
+
+     def __getitem__(self, index):
+         raise NotImplementedError
+
+     def __len__(self):
+         return len(self.df)
+
+
+ class CSVRDataset(tud.Dataset):
+     def __init__(
+         self,
+         csv_file: str,
+         splitter: str,
+         smiles_col: str,
+         target_col: str = None,
+         missing_label: str = None,
+         target_transform: t.Union[str, t.List[str]] = None,
+         **kwargs
+     ) -> None:
+         super().__init__()
+         self.csv_file = csv_file
+         df = pd.read_csv(
+             self.csv_file,
+             sep=splitter,
+             **kwargs
+         )
+         self.df = rm_index(df)
+         self.smiles_col = smiles_col
+         self.target_col = target_col
+         self.miss_label = missing_label
+         if target_transform is not None:
+             self.target_transform = label_trans_dict[target_transform]
+         else:
+             self.target_transform = None
+
+     def __getitem__(self, index):
+         raise NotImplementedError
+
+     def __len__(self):
+         return len(self.df)
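CSVDataset deliberately leaves __getitem__ unimplemented; subclasses are expected to featurize the SMILES column and apply target_transform per target. A minimal sketch, where the featurizer and the transform call signature are assumptions rather than package APIs:

from hdl.data.dataset.base_dataset import CSVDataset


class ToyCSVDataset(CSVDataset):
    # Hypothetical subclass; `featurize` stands in for a real
    # fingerprint/graph featurizer and is not part of this package.
    def __getitem__(self, index):
        row = self.df.iloc[index]
        x = featurize(row[self.smiles_col])  # placeholder featurizer
        targets = [row[col] for col in self.target_cols]
        if self.target_transform is not None:
            # transform signature (label, num_classes) is an assumption
            targets = [
                trans(y, n) for trans, y, n in
                zip(self.target_transform, targets, self.num_classes)
            ]
        return x, targets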