hjxdl-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/__init__.py +0 -0
- hdl/_version.py +16 -0
- hdl/args/__init__.py +0 -0
- hdl/args/loss_args.py +5 -0
- hdl/controllers/__init__.py +0 -0
- hdl/controllers/al/__init__.py +0 -0
- hdl/controllers/al/al.py +0 -0
- hdl/controllers/al/dispatcher.py +0 -0
- hdl/controllers/al/feedback.py +0 -0
- hdl/controllers/explain/__init__.py +0 -0
- hdl/controllers/explain/shapley.py +293 -0
- hdl/controllers/explain/subgraphx.py +865 -0
- hdl/controllers/train/__init__.py +0 -0
- hdl/controllers/train/rxn_train.py +219 -0
- hdl/controllers/train/train.py +50 -0
- hdl/controllers/train/train_ginet.py +316 -0
- hdl/controllers/train/trainer_base.py +155 -0
- hdl/controllers/train/trainer_iterative.py +389 -0
- hdl/data/__init__.py +0 -0
- hdl/data/dataset/__init__.py +0 -0
- hdl/data/dataset/base_dataset.py +98 -0
- hdl/data/dataset/fp/__init__.py +0 -0
- hdl/data/dataset/fp/fp_dataset.py +122 -0
- hdl/data/dataset/graph/__init__.py +0 -0
- hdl/data/dataset/graph/chiral.py +62 -0
- hdl/data/dataset/graph/gin.py +255 -0
- hdl/data/dataset/graph/molnet.py +362 -0
- hdl/data/dataset/loaders/__init__.py +0 -0
- hdl/data/dataset/loaders/chiral_graph.py +71 -0
- hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
- hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
- hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
- hdl/data/dataset/loaders/general.py +23 -0
- hdl/data/dataset/loaders/spliter.py +86 -0
- hdl/data/dataset/samplers/__init__.py +0 -0
- hdl/data/dataset/samplers/chiral.py +19 -0
- hdl/data/dataset/seq/__init__.py +0 -0
- hdl/data/dataset/seq/rxn_dataset.py +61 -0
- hdl/data/dataset/utils.py +31 -0
- hdl/data/to_mols.py +0 -0
- hdl/features/__init__.py +0 -0
- hdl/features/fp/__init__.py +0 -0
- hdl/features/fp/features_generators.py +235 -0
- hdl/features/graph/__init__.py +0 -0
- hdl/features/graph/featurization.py +297 -0
- hdl/features/utils/__init__.py +0 -0
- hdl/features/utils/utils.py +111 -0
- hdl/layers/__init__.py +0 -0
- hdl/layers/general/__init__.py +0 -0
- hdl/layers/general/gp.py +14 -0
- hdl/layers/general/linear.py +641 -0
- hdl/layers/graph/__init__.py +0 -0
- hdl/layers/graph/chiral_graph.py +230 -0
- hdl/layers/graph/gcn.py +16 -0
- hdl/layers/graph/gin.py +45 -0
- hdl/layers/graph/tetra.py +158 -0
- hdl/layers/graph/transformer.py +188 -0
- hdl/layers/sequential/__init__.py +0 -0
- hdl/metric_loss/__init__.py +0 -0
- hdl/metric_loss/loss.py +79 -0
- hdl/metric_loss/metric.py +178 -0
- hdl/metric_loss/multi_label.py +42 -0
- hdl/metric_loss/nt_xent.py +65 -0
- hdl/models/__init__.py +0 -0
- hdl/models/chiral_gnn.py +176 -0
- hdl/models/fast_transformer.py +234 -0
- hdl/models/ginet.py +189 -0
- hdl/models/linear.py +137 -0
- hdl/models/model_dict.py +18 -0
- hdl/models/norm_flows.py +33 -0
- hdl/models/optim_dict.py +16 -0
- hdl/models/rxn.py +63 -0
- hdl/models/utils.py +83 -0
- hdl/ops/__init__.py +0 -0
- hdl/ops/utils.py +42 -0
- hdl/optims/__init__.py +0 -0
- hdl/optims/nadam.py +86 -0
- hdl/utils/__init__.py +0 -0
- hdl/utils/chemical_tools/__init__.py +2 -0
- hdl/utils/chemical_tools/query_info.py +149 -0
- hdl/utils/chemical_tools/sdf.py +20 -0
- hdl/utils/database_tools/__init__.py +0 -0
- hdl/utils/database_tools/connect.py +28 -0
- hdl/utils/general/__init__.py +0 -0
- hdl/utils/general/glob.py +21 -0
- hdl/utils/schedulers/__init__.py +0 -0
- hdl/utils/schedulers/norm_lr.py +108 -0
- hjxdl-0.0.1.dist-info/METADATA +19 -0
- hjxdl-0.0.1.dist-info/RECORD +91 -0
- hjxdl-0.0.1.dist-info/WHEEL +5 -0
- hjxdl-0.0.1.dist-info/top_level.txt +1 -0
hdl/data/dataset/fp/fp_dataset.py
@@ -0,0 +1,122 @@
import typing as t

import torch
from rdkit import Chem
# import torch.utils.data as tud

from hdl.data.dataset.base_dataset import CSVDataset, CSVRDataset
from hdl.features.fp.features_generators import (
    get_features_generator,
    get_available_features_generators,
    FP_BITS_DICT
)
# from hdl.features.fp.rxn import get_rxnrep_fingerprint


class FPDataset(CSVDataset):
    def __init__(
        self,
        csv_file: str,
        splitter: str,
        smiles_cols: t.List,
        target_cols: t.List = [],
        missing_labels: t.List = [],
        num_classes: t.List = [],
        target_transform: t.Union[str, t.List[str]] = None,
        fp_type: str = 'morgan_count',
        **kwargs
    ) -> None:
        super().__init__(
            csv_file,
            splitter=splitter,
            smiles_col=smiles_cols,
            target_cols=target_cols,
            num_classes=num_classes,
            target_transform=target_transform,
            **kwargs
        )
        self.smiles_cols = smiles_cols
        assert fp_type in get_available_features_generators()
        self.fp_type = fp_type
        self.fp_generator = get_features_generator(self.fp_type)
        self.fp_numbits = FP_BITS_DICT[self.fp_type]
        self.missing_labels = missing_labels

    def __getitem__(self, index):
        smiles_list = self.df.loc[index][self.smiles_cols].tolist()

        fingerprint_list = list(
            map(
                lambda x: torch.LongTensor(self.fp_generator(Chem.MolFromSmiles(x))),
                smiles_list
            )
        )
        if any(self.target_cols):
            target_list = self.df.loc[index][self.target_cols].tolist()

            # process with missing label
            final_targets = []
            for target, missing_label in zip(target_list, self.missing_labels):
                if missing_label is not None and target == missing_label:
                    final_targets.append(float('nan'))
                else:
                    final_targets.append(target)

            if self.target_transform is None:
                return fingerprint_list, final_targets
            else:
                # print(final_targets)
                target_tensors = [
                    trans(target, num_class, missing_label=float('nan'))
                    for trans, target, num_class in zip(
                        self.target_transform,
                        final_targets,
                        self.num_classes
                    )
                ]
                # print(target_tensors)
                return fingerprint_list, target_tensors, final_targets
        else:
            return fingerprint_list


class FPRDataset(CSVRDataset):
    def __init__(
        self,
        csv_file: str,
        splitter: str,
        smiles_col: str,
        target_col: str = None,
        missing_label: str = None,
        target_transform: t.Union[str, t.List[str]] = None,
        fp_type: str = 'morgan_count',
        **kwargs
    ) -> None:
        super().__init__(
            csv_file,
            splitter=splitter,
            smiles_col=smiles_col,
            target_col=target_col,
            target_transform=target_transform,
            missing_label=missing_label,
            **kwargs
        )
        assert fp_type in get_available_features_generators()
        self.fp_type = fp_type
        self.fp_generator = get_features_generator(self.fp_type)
        self.fp_numbits = FP_BITS_DICT[self.fp_type]
        self.missing_label = missing_label

    def __getitem__(self, index):
        smiles = self.df.loc[index][self.smiles_col]
        try:
            fp = torch.LongTensor(self.fp_generator(Chem.MolFromSmiles(smiles)))
        except Exception as _:
            fp = torch.zeros(self.fp_numbits).long()

        if self.target_col is not None:
            target = self.df.loc[index][self.target_col]
            target = (target, )
            return fp, target
        else:
            return fp
File without changes
hdl/data/dataset/graph/chiral.py
@@ -0,0 +1,62 @@
import typing as t

import numpy as np
import torch
from torch._C import dtype
import torch_geometric as tg
from torch_geometric.data import Dataset

from hdl.features.graph.featurization import MolGraph


class MolDataset(Dataset):

    def __init__(
        self,
        smiles: t.List,
        labels: t.List,
        chiral_features: bool = False,
        global_chiral_features: bool = False,
    ):
        super(MolDataset, self).__init__()

        # self.split = list(range(len(smiles))) # fix this
        # self.smiles = [smiles[i] for i in self.split]
        # self.labels = [labels[i] for i in self.split]
        self.smiles = smiles
        self.labels = labels
        # self.data_map = {k: v for k, v in zip(range(len(self.smiles)), self.split)}
        # self.args = args
        self.chiral_features = chiral_features
        self.global_chiral_features = global_chiral_features

        self.mean = np.mean(self.labels)
        self.std = np.std(self.labels)

    def process_key(self, key):
        smi = self.smiles[key]
        molgraph = MolGraph(
            smi,
            self.chiral_features,
            self.global_chiral_features
        )
        mol = self.molgraph2data(molgraph, key)
        return mol

    def molgraph2data(self, molgraph, key):
        data = tg.data.Data()
        data.x = torch.tensor(molgraph.f_atoms, dtype=torch.float)
        data.edge_index = torch.tensor(molgraph.edge_index, dtype=torch.long).t().contiguous()
        data.edge_attr = torch.tensor(molgraph.f_bonds, dtype=torch.float)
        data.y = torch.tensor([self.labels[key]], dtype=torch.float)
        data.parity_atoms = torch.tensor(molgraph.parity_atoms, dtype=torch.long)
        data.parity_bond_index = torch.tensor(molgraph.parity_bond_index, dtype=torch.long)
        data.smiles = self.smiles[key]

        return data

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, key):
        return self.process_key(key)
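A minimal sketch of how the MolDataset above might be constructed. The SMILES strings and labels are made up, the featurization is delegated to MolGraph (defined in hdl/features/graph/featurization.py, not shown here), and whether a bare torch_geometric Dataset subclass like this initializes cleanly depends on the installed torch_geometric version.

    # Illustrative inputs only.
    from hdl.data.dataset.graph.chiral import MolDataset

    dataset = MolDataset(
        smiles=['CC(N)C(=O)O', 'C1CCCCC1'],
        labels=[1.0, 0.0],
        chiral_features=True,
        global_chiral_features=True,
    )
    data = dataset[0]  # torch_geometric Data with x, edge_index, edge_attr, y, parity_atoms, parity_bond_index, smiles
    print(len(dataset), dataset.mean, dataset.std)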
hdl/data/dataset/graph/gin.py
@@ -0,0 +1,255 @@
# import os
import csv
import math
# import time
import random
# import networkx as nx
import numpy as np
from copy import deepcopy
import typing as t

import torch
# import torch.nn.functional as F
# from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
# import torchvision.transforms as transforms

# from torch_scatter import scatter
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader

# import rdkit
from rdkit import Chem
# from rdkit.Chem.rdchem import HybridizationType
from rdkit.Chem.rdchem import BondType as BT
# from rdkit.Chem import AllChem

from hdl.data.dataset.utils import read_smiles


__all__ = [
    "MoleculeDataset",
    "MoleculeDatasetWrapper"
]


ATOM_LIST = list(range(1, 119))
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER
]
BOND_LIST = [
    BT.SINGLE,
    BT.DOUBLE,
    BT.TRIPLE,
    BT.AROMATIC
]
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT
]


class MoleculeDataset(Dataset):
    def __init__(
        self,
        data_path,
        file_type: str = 'smi',
        smi_col_names: t.List = [],
        y_col_name: str = None,
    ):
        super(Dataset, self).__init__()
        self.smiles_data = read_smiles(
            data_path=data_path,
            file_type=file_type,
            smi_col_names=smi_col_names,
            y_col_name=y_col_name
        )
        self.smi_col_names = smi_col_names
        self.y_col_name = y_col_name

    def __getitem__(
        self,
        idx: int
    ):
        if any(self.smi_col_names):
            item = [
                self.getitem(smiles)
                for smiles in self.smiles_data[idx][: len(self.smi_col_names)]
            ]
            if self.y_col_name is not None:
                item.append(float(self.smiles_data[idx][-1]))
            return item
        else:
            return self.getitem(self.smiles_data[idx])

    def getitem(self, smiles):
        mol = Chem.MolFromSmiles(smiles)
        # mol = Chem.AddHs(mol)

        N = mol.GetNumAtoms()
        M = mol.GetNumBonds()

        type_idx = []
        chirality_idx = []
        atomic_number = []
        # aromatic = []
        # sp, sp2, sp3, sp3d = [], [], [], []
        # num_hs = []
        for atom in mol.GetAtoms():
            type_idx.append(ATOM_LIST.index(atom.GetAtomicNum()))
            chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))
            atomic_number.append(atom.GetAtomicNum())
            # aromatic.append(1 if atom.GetIsAromatic() else 0)
            # hybridization = atom.GetHybridization()
            # sp.append(1 if hybridization == HybridizationType.SP else 0)
            # sp2.append(1 if hybridization == HybridizationType.SP2 else 0)
            # sp3.append(1 if hybridization == HybridizationType.SP3 else 0)
            # sp3d.append(1 if hybridization == HybridizationType.SP3D else 0)

        # z = torch.tensor(atomic_number, dtype=torch.long)
        x1 = torch.tensor(type_idx, dtype=torch.long).view(-1, 1)
        x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1, 1)
        x = torch.cat([x1, x2], dim=-1)
        # x2 = torch.tensor([atomic_number, aromatic, sp, sp2, sp3, sp3d, num_hs],
        #                   dtype=torch.float).t().contiguous()
        # x = torch.cat([x1.to(torch.float), x2], dim=-1)

        row, col, edge_feat = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row += [start, end]
            col += [end, start]
            # edge_type += 2 * [MOL_BONDS[bond.GetBondType()]]
            edge_feat.append([
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ])
            edge_feat.append([
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ])

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(np.array(edge_feat), dtype=torch.long)

        # randomly mask a subgraph of the molecule
        num_mask_nodes = max([1, math.floor(0.25 * N)])
        num_mask_edges = max([0, math.floor(0.25 * M)])
        mask_nodes_i = random.sample(list(range(N)), num_mask_nodes)
        mask_nodes_j = random.sample(list(range(N)), num_mask_nodes)

        mask_edges_i_single = random.sample(list(range(M)), num_mask_edges)
        mask_edges_j_single = random.sample(list(range(M)), num_mask_edges)
        mask_edges_i = [2 * i for i in mask_edges_i_single] + [2 * i + 1 for i in mask_edges_i_single]
        mask_edges_j = [2 * i for i in mask_edges_j_single] + [2 * i + 1 for i in mask_edges_j_single]

        x_i = deepcopy(x)
        for atom_idx in mask_nodes_i:
            x_i[atom_idx, :] = torch.tensor([len(ATOM_LIST), 0])
        edge_index_i = torch.zeros((2, 2 * (M - num_mask_edges)), dtype=torch.long)
        edge_attr_i = torch.zeros((2 * (M - num_mask_edges), 2), dtype=torch.long)
        count = 0
        for bond_idx in range(2 * M):
            if bond_idx not in mask_edges_i:
                edge_index_i[:, count] = edge_index[:, bond_idx]
                edge_attr_i[count, :] = edge_attr[bond_idx, :]
                count += 1
        data_i = Data(x=x_i, edge_index=edge_index_i, edge_attr=edge_attr_i)

        x_j = deepcopy(x)
        for atom_idx in mask_nodes_j:
            x_j[atom_idx, :] = torch.tensor([len(ATOM_LIST), 0])
        edge_index_j = torch.zeros((2, 2 * (M - num_mask_edges)), dtype=torch.long)
        edge_attr_j = torch.zeros((2 * (M - num_mask_edges), 2), dtype=torch.long)
        count = 0
        for bond_idx in range(2 * M):
            if bond_idx not in mask_edges_j:
                edge_index_j[:, count] = edge_index[:, bond_idx]
                edge_attr_j[count, :] = edge_attr[bond_idx, :]
                count += 1
        data_j = Data(x=x_j, edge_index=edge_index_j, edge_attr=edge_attr_j)

        return data_i, data_j

    def __len__(self):
        return len(self.smiles_data)


class MoleculeDatasetWrapper(object):
    def __init__(
        self,
        batch_size,
        num_workers,
        valid_size,
        data_path,
        file_type: str = 'smi',
        smi_col_names: t.List = [],
        y_col_name: str = None,
    ):
        super(object, self).__init__()
        self.data_path = data_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.valid_size = valid_size
        self.file_type = file_type
        self.smi_col_names = smi_col_names
        self.y_col_name = y_col_name

    def get_data_loaders(self):
        train_dataset = MoleculeDataset(data_path=self.data_path)
        train_loader, valid_loader = self.get_train_validation_data_loaders(train_dataset)
        return train_loader, valid_loader

    def get_test_loader(self, shuffle=False):
        test_dataset = MoleculeDataset(
            data_path=self.data_path,
            file_type=self.file_type,
            smi_col_names=self.smi_col_names,
            y_col_name=self.y_col_name
        )
        test_loader = self.get_test_data_loader(
            test_dataset,
            shuffle=shuffle
        )
        return test_loader

    def get_train_validation_data_loaders(self, train_dataset):
        # obtain training indices that will be used for validation
        num_train = len(train_dataset)
        indices = list(range(num_train))
        np.random.shuffle(indices)

        split = int(np.floor(self.valid_size * num_train))
        train_idx, valid_idx = indices[split:], indices[:split]

        # define samplers for obtaining training and validation batches
        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=train_sampler,
                                  num_workers=self.num_workers, drop_last=True)

        valid_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=valid_sampler,
                                  num_workers=self.num_workers, drop_last=True)

        return train_loader, valid_loader

    def get_test_data_loader(
        self,
        test_dataset,
        shuffle=False
    ):
        # num_test = len(test_dataset)
        # indices = list(range(num_test))
        test_loader = DataLoader(
            test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            drop_last=False,
            shuffle=shuffle
        )
        return test_loader
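A minimal sketch of how the MoleculeDatasetWrapper above might be driven for the self-supervised (two masked graph views) setting. The data path is hypothetical, the accepted file format is whatever read_smiles in hdl/data/dataset/utils.py expects, and note that get_data_loaders constructs MoleculeDataset from data_path alone, so the default 'smi' file type applies to training.

    # Hypothetical SMILES file, for illustration only.
    from hdl.data.dataset.graph.gin import MoleculeDatasetWrapper

    wrapper = MoleculeDatasetWrapper(
        batch_size=32,
        num_workers=0,
        valid_size=0.1,
        data_path='pretrain.smi',
    )
    train_loader, valid_loader = wrapper.get_data_loaders()
    for data_i, data_j in train_loader:
        # each step yields two randomly masked graph views of the same molecules
        break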