hjxdl 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/__init__.py +0 -0
- hdl/_version.py +16 -0
- hdl/args/__init__.py +0 -0
- hdl/args/loss_args.py +5 -0
- hdl/controllers/__init__.py +0 -0
- hdl/controllers/al/__init__.py +0 -0
- hdl/controllers/al/al.py +0 -0
- hdl/controllers/al/dispatcher.py +0 -0
- hdl/controllers/al/feedback.py +0 -0
- hdl/controllers/explain/__init__.py +0 -0
- hdl/controllers/explain/shapley.py +293 -0
- hdl/controllers/explain/subgraphx.py +865 -0
- hdl/controllers/train/__init__.py +0 -0
- hdl/controllers/train/rxn_train.py +219 -0
- hdl/controllers/train/train.py +50 -0
- hdl/controllers/train/train_ginet.py +316 -0
- hdl/controllers/train/trainer_base.py +155 -0
- hdl/controllers/train/trainer_iterative.py +389 -0
- hdl/data/__init__.py +0 -0
- hdl/data/dataset/__init__.py +0 -0
- hdl/data/dataset/base_dataset.py +98 -0
- hdl/data/dataset/fp/__init__.py +0 -0
- hdl/data/dataset/fp/fp_dataset.py +122 -0
- hdl/data/dataset/graph/__init__.py +0 -0
- hdl/data/dataset/graph/chiral.py +62 -0
- hdl/data/dataset/graph/gin.py +255 -0
- hdl/data/dataset/graph/molnet.py +362 -0
- hdl/data/dataset/loaders/__init__.py +0 -0
- hdl/data/dataset/loaders/chiral_graph.py +71 -0
- hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
- hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
- hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
- hdl/data/dataset/loaders/general.py +23 -0
- hdl/data/dataset/loaders/spliter.py +86 -0
- hdl/data/dataset/samplers/__init__.py +0 -0
- hdl/data/dataset/samplers/chiral.py +19 -0
- hdl/data/dataset/seq/__init__.py +0 -0
- hdl/data/dataset/seq/rxn_dataset.py +61 -0
- hdl/data/dataset/utils.py +31 -0
- hdl/data/to_mols.py +0 -0
- hdl/features/__init__.py +0 -0
- hdl/features/fp/__init__.py +0 -0
- hdl/features/fp/features_generators.py +235 -0
- hdl/features/graph/__init__.py +0 -0
- hdl/features/graph/featurization.py +297 -0
- hdl/features/utils/__init__.py +0 -0
- hdl/features/utils/utils.py +111 -0
- hdl/layers/__init__.py +0 -0
- hdl/layers/general/__init__.py +0 -0
- hdl/layers/general/gp.py +14 -0
- hdl/layers/general/linear.py +641 -0
- hdl/layers/graph/__init__.py +0 -0
- hdl/layers/graph/chiral_graph.py +230 -0
- hdl/layers/graph/gcn.py +16 -0
- hdl/layers/graph/gin.py +45 -0
- hdl/layers/graph/tetra.py +158 -0
- hdl/layers/graph/transformer.py +188 -0
- hdl/layers/sequential/__init__.py +0 -0
- hdl/metric_loss/__init__.py +0 -0
- hdl/metric_loss/loss.py +79 -0
- hdl/metric_loss/metric.py +178 -0
- hdl/metric_loss/multi_label.py +42 -0
- hdl/metric_loss/nt_xent.py +65 -0
- hdl/models/__init__.py +0 -0
- hdl/models/chiral_gnn.py +176 -0
- hdl/models/fast_transformer.py +234 -0
- hdl/models/ginet.py +189 -0
- hdl/models/linear.py +137 -0
- hdl/models/model_dict.py +18 -0
- hdl/models/norm_flows.py +33 -0
- hdl/models/optim_dict.py +16 -0
- hdl/models/rxn.py +63 -0
- hdl/models/utils.py +83 -0
- hdl/ops/__init__.py +0 -0
- hdl/ops/utils.py +42 -0
- hdl/optims/__init__.py +0 -0
- hdl/optims/nadam.py +86 -0
- hdl/utils/__init__.py +0 -0
- hdl/utils/chemical_tools/__init__.py +2 -0
- hdl/utils/chemical_tools/query_info.py +149 -0
- hdl/utils/chemical_tools/sdf.py +20 -0
- hdl/utils/database_tools/__init__.py +0 -0
- hdl/utils/database_tools/connect.py +28 -0
- hdl/utils/general/__init__.py +0 -0
- hdl/utils/general/glob.py +21 -0
- hdl/utils/schedulers/__init__.py +0 -0
- hdl/utils/schedulers/norm_lr.py +108 -0
- hjxdl-0.0.1.dist-info/METADATA +19 -0
- hjxdl-0.0.1.dist-info/RECORD +91 -0
- hjxdl-0.0.1.dist-info/WHEEL +5 -0
- hjxdl-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,362 @@
|
|
1
|
+
# import os
import copy
import os.path as osp
import re
# Fix: ``Sequence`` must come from collections.abc — the plain
# ``collections`` alias was deprecated in 3.3 and removed in Python 3.10.
from collections.abc import Sequence
from itertools import product, repeat
from typing import List, Tuple, Dict, Optional, Union

import numpy as np
import torch
from torch import Tensor
from torch_geometric.data import (
    InMemoryDataset, Data
)
from torch_geometric.data.dataset import Dataset, IndexType

try:
    from rdkit import Chem
except ImportError:
    Chem = None

from jupyfuncs.pbar import tqdm
|
23
|
+
|
24
|
+
|
25
|
+
# Categorical vocabularies used to turn RDKit atom/bond attributes into
# integer feature indices (each feature value is encoded as its position in
# the corresponding list; see the ``.index(...)`` calls in ``process``).
# Node (atom) feature vocabulary — 9 features per atom.
x_map = {
    'atomic_num':
    list(range(0, 119)),
    'chirality': [
        'CHI_UNSPECIFIED',
        'CHI_TETRAHEDRAL_CW',
        'CHI_TETRAHEDRAL_CCW',
        'CHI_OTHER',
    ],
    'degree':
    list(range(0, 11)),
    'formal_charge':
    list(range(-5, 7)),
    'num_hs':
    list(range(0, 9)),
    'num_radical_electrons':
    list(range(0, 5)),
    'hybridization': [
        'UNSPECIFIED',
        'S',
        'SP',
        'SP2',
        'SP3',
        'SP3D',
        'SP3D2',
        'OTHER',
    ],
    'is_aromatic': [False, True],
    'is_in_ring': [False, True],
}

# Edge (bond) feature vocabulary — 3 features per bond.
e_map = {
    'bond_type': [
        'misc',
        'SINGLE',
        'DOUBLE',
        'TRIPLE',
        'AROMATIC',
    ],
    'stereo': [
        'STEREONONE',
        'STEREOZ',
        'STEREOE',
        'STEREOCIS',
        'STEREOTRANS',
        'STEREOANY',
    ],
    'is_conjugated': [False, True],
}
|
74
|
+
|
75
|
+
|
76
|
+
class MoleculeNet(torch.utils.data.Dataset):
    r"""The `MoleculeNet <http://moleculenet.ai/datasets-1>`_ benchmark
    collection from the `"MoleculeNet: A Benchmark for Molecular Machine
    Learning" <https://arxiv.org/abs/1703.00564>`_ paper, containing datasets
    from physical chemistry, biophysics and physiology.

    All datasets come with the additional node and edge features introduced by
    the `Open Graph Benchmark <https://ogb.stanford.edu/docs/graphprop/>`_.

    Args:
        root (string): Root directory where the processed dataset is saved.
        file_type (string): Format of the raw input file.
            (default: :obj:`'smi_in_csv'`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    def __init__(self, root, file_type='smi_in_csv',
                 transform=None, pre_transform=None,
                 pre_filter=None):
        # Fix: the ``file_type`` argument was accepted but ignored (the
        # attribute was hard-coded to 'smi_in_csv').
        self.file_type = file_type

        if Chem is None:
            raise ImportError('`MoleculeNet` requires `rdkit`.')

        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter

        # self.name = name.lower()
        # assert self.name in self.names.keys()
        # super(MoleculeNet, self).__init__(root, transform, pre_transform,
        #                                   pre_filter)
        # Fix: ``root`` was accepted but ignored (root_dir hard-coded to ''),
        # so 'processed.pt' always landed in the current working directory.
        # With root='' the behaviour is unchanged.
        self.root_dir = root
        self.processed_file = osp.join(
            self.root_dir, 'processed.pt'
        )
        # Eagerly featurize the raw file, then load the collated tensors.
        self.process()
        self.data, self.slices = torch.load(self.processed_file)
|
126
|
+
|
127
|
+
    def process(self):
        """Parse the raw CSV file into a list of PyG ``Data`` graphs and save
        the collated result to ``self.processed_file``.

        NOTE(review): this method reads ``self.raw_paths``, ``self.names`` and
        ``self.name``, none of which are defined anywhere in this class (they
        belong to the upstream torch_geometric ``MoleculeNet``/``Dataset``
        this code was adapted from) — as written, calling it raises
        AttributeError; confirm how these are meant to be provided.
        """
        with open(self.raw_paths[0], 'r') as f:
            # Drop the CSV header line and the trailing empty line.
            dataset = f.read().split('\n')[1:-1]
        dataset = [x for x in dataset if len(x) > 0]  # Filter empty lines.

        data_list = []
        for line in tqdm(dataset):
            line = re.sub(r'\".*\"', '', line)  # Replace ".*" strings.
            line = line.split(',')

            # Column indices of the SMILES string and the target(s) are looked
            # up in the per-dataset table ``self.names`` (see NOTE above).
            smiles = line[self.names[self.name][3]]
            ys = line[self.names[self.name][4]]
            ys = ys if isinstance(ys, list) else [ys]

            # Empty label fields become NaN so multi-task rows keep shape.
            ys = [float(y) if len(y) > 0 else float('NaN') for y in ys]
            y = torch.tensor(ys, dtype=torch.float).view(1, -1)

            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # Skip unparsable SMILES rather than failing the whole run.
                continue

            # Node features: 9 categorical indices per atom (see ``x_map``).
            xs = []
            for atom in mol.GetAtoms():
                x = []
                x.append(x_map['atomic_num'].index(atom.GetAtomicNum()))
                x.append(x_map['chirality'].index(str(atom.GetChiralTag())))
                x.append(x_map['degree'].index(atom.GetTotalDegree()))
                x.append(x_map['formal_charge'].index(atom.GetFormalCharge()))
                x.append(x_map['num_hs'].index(atom.GetTotalNumHs()))
                x.append(x_map['num_radical_electrons'].index(
                    atom.GetNumRadicalElectrons()))
                x.append(x_map['hybridization'].index(
                    str(atom.GetHybridization())))
                x.append(x_map['is_aromatic'].index(atom.GetIsAromatic()))
                x.append(x_map['is_in_ring'].index(atom.IsInRing()))
                xs.append(x)

            x = torch.tensor(xs, dtype=torch.long).view(-1, 9)

            # Edge features: 3 categorical indices per bond (see ``e_map``);
            # each bond is stored in both directions (undirected graph).
            edge_indices, edge_attrs = [], []
            for bond in mol.GetBonds():
                i = bond.GetBeginAtomIdx()
                j = bond.GetEndAtomIdx()

                e = []
                e.append(e_map['bond_type'].index(str(bond.GetBondType())))
                e.append(e_map['stereo'].index(str(bond.GetStereo())))
                e.append(e_map['is_conjugated'].index(bond.GetIsConjugated()))

                edge_indices += [[i, j], [j, i]]
                edge_attrs += [e, e]

            edge_index = torch.tensor(edge_indices)
            edge_index = edge_index.t().to(torch.long).view(2, -1)
            edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3)

            # Sort indices.
            if edge_index.numel() > 0:
                perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
                edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]

            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y,
                        smiles=smiles)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        torch.save(
            self.collate(data_list),
            self.processed_file
        )
|
203
|
+
|
204
|
+
def collate(data_list: List[Data]) -> Tuple[Data, Dict[str, Tensor]]:
|
205
|
+
r"""Collates a python list of data objects to the internal storage
|
206
|
+
format of :class:`torch_geometric.data.InMemoryDataset`."""
|
207
|
+
keys = data_list[0].keys
|
208
|
+
data = data_list[0].__class__()
|
209
|
+
|
210
|
+
for key in keys:
|
211
|
+
data[key] = []
|
212
|
+
slices = {key: [0] for key in keys}
|
213
|
+
|
214
|
+
for item, key in product(data_list, keys):
|
215
|
+
data[key].append(item[key])
|
216
|
+
if isinstance(item[key], Tensor) and item[key].dim() > 0:
|
217
|
+
cat_dim = item.__cat_dim__(key, item[key])
|
218
|
+
cat_dim = 0 if cat_dim is None else cat_dim
|
219
|
+
s = slices[key][-1] + item[key].size(cat_dim)
|
220
|
+
else:
|
221
|
+
s = slices[key][-1] + 1
|
222
|
+
slices[key].append(s)
|
223
|
+
|
224
|
+
def __repr__(self):
|
225
|
+
return 'MoleculeNet ({})'.format(len(self))
|
226
|
+
|
227
|
+
def copy(self, idx: Optional[IndexType] = None):
|
228
|
+
if idx is None:
|
229
|
+
data_list = [self.get(i) for i in range(len(self))]
|
230
|
+
else:
|
231
|
+
data_list = [self.get(i) for i in self.index_select(idx).indices()]
|
232
|
+
|
233
|
+
dataset = copy.copy(self)
|
234
|
+
dataset._indices = None
|
235
|
+
dataset._data_list = data_list
|
236
|
+
dataset.data, dataset.slices = self.collate(data_list)
|
237
|
+
return dataset
|
238
|
+
|
239
|
+
@property
|
240
|
+
def num_classes(self) -> int:
|
241
|
+
r"""The number of classes in the dataset."""
|
242
|
+
y = self.data.y
|
243
|
+
if y is None:
|
244
|
+
return 0
|
245
|
+
elif y.numel() == y.size(0) and not torch.is_floating_point(y):
|
246
|
+
return int(self.data.y.max()) + 1
|
247
|
+
elif y.numel() == y.size(0) and torch.is_floating_point(y):
|
248
|
+
return torch.unique(y).numel()
|
249
|
+
else:
|
250
|
+
return self.data.y.size(-1)
|
251
|
+
|
252
|
+
def len(self) -> int:
|
253
|
+
for item in self.slices.values():
|
254
|
+
return len(item) - 1
|
255
|
+
return 0
|
256
|
+
|
257
|
+
    def get(self, idx: int) -> Data:
        """Reconstruct the ``idx``-th ``Data`` object from the collated
        storage (``self.data`` + ``self.slices``), with per-instance caching
        in ``self._data_list`` when that attribute exists.
        """
        if hasattr(self, '_data_list'):
            if self._data_list is None:
                # Lazily allocate the cache on first access.
                self._data_list = self.len() * [None]
            else:
                data = self._data_list[idx]
                if data is not None:
                    # Cache hit: hand out a shallow copy so callers cannot
                    # mutate the cached object.
                    return copy.copy(data)

        data = self.data.__class__()
        if hasattr(self.data, '__num_nodes__'):
            data.num_nodes = self.data.__num_nodes__[idx]

        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            # NOTE(review): ``.item()`` assumes the slices values are tensors,
            # while ``collate`` in this file builds plain Python lists —
            # confirm which format is actually persisted.
            start, end = slices[idx].item(), slices[idx + 1].item()
            if torch.is_tensor(item):
                # Slice only along the key-specific concatenation dimension;
                # every other dimension is taken whole.
                s = list(repeat(slice(None), item.dim()))
                cat_dim = self.data.__cat_dim__(key, item)
                if cat_dim is None:
                    cat_dim = 0
                s[cat_dim] = slice(start, end)
            elif start + 1 == end:
                # One-entry-per-example keys: boundaries are 0..N, so
                # slices[start] == start and this indexes the start-th value.
                s = slices[start]
            else:
                s = slice(start, end)
            data[key] = item[s]

        if hasattr(self, '_data_list'):
            self._data_list[idx] = copy.copy(data)

        return data
|
289
|
+
|
290
|
+
def __len__(self) -> int:
|
291
|
+
r"""The number of examples in the dataset."""
|
292
|
+
return len(self.indices())
|
293
|
+
|
294
|
+
def __getitem__(
|
295
|
+
self,
|
296
|
+
idx: Union[int, np.integer, IndexType],
|
297
|
+
) -> Union['Dataset', Data]:
|
298
|
+
r"""In case :obj:`idx` is of type integer, will return the data object
|
299
|
+
at index :obj:`idx` (and transforms it in case :obj:`transform` is
|
300
|
+
present).
|
301
|
+
In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a
|
302
|
+
tuple, a PyTorch :obj:`LongTensor` or a :obj:`BoolTensor`, or a numpy
|
303
|
+
:obj:`np.array`, will return a subset of the dataset at the specified
|
304
|
+
indices."""
|
305
|
+
if (isinstance(idx, (int, np.integer))
|
306
|
+
or (isinstance(idx, Tensor) and idx.dim() == 0)
|
307
|
+
or (isinstance(idx, np.ndarray) and np.isscalar(idx))):
|
308
|
+
|
309
|
+
data = self.get(self.indices()[idx])
|
310
|
+
data = data if self.transform is None else self.transform(data)
|
311
|
+
return data
|
312
|
+
|
313
|
+
else:
|
314
|
+
return self.index_select(idx)
|
315
|
+
|
316
|
+
def index_select(self, idx: IndexType) -> 'Dataset':
|
317
|
+
indices = self.indices()
|
318
|
+
|
319
|
+
if isinstance(idx, slice):
|
320
|
+
indices = indices[idx]
|
321
|
+
|
322
|
+
elif isinstance(idx, Tensor) and idx.dtype == torch.long:
|
323
|
+
return self.index_select(idx.flatten().tolist())
|
324
|
+
|
325
|
+
elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
|
326
|
+
idx = idx.flatten().nonzero(as_tuple=False)
|
327
|
+
return self.index_select(idx.flatten().tolist())
|
328
|
+
|
329
|
+
elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
|
330
|
+
return self.index_select(idx.flatten().tolist())
|
331
|
+
|
332
|
+
elif isinstance(idx, np.ndarray) and idx.dtype == np.bool:
|
333
|
+
idx = idx.flatten().nonzero()[0]
|
334
|
+
return self.index_select(idx.flatten().tolist())
|
335
|
+
|
336
|
+
elif isinstance(idx, Sequence) and not isinstance(idx, str):
|
337
|
+
indices = [indices[i] for i in idx]
|
338
|
+
|
339
|
+
else:
|
340
|
+
raise IndexError(
|
341
|
+
f"Only integers, slices (':'), list, tuples, torch.tensor and "
|
342
|
+
f"np.ndarray of dtype long or bool are valid indices (got "
|
343
|
+
f"'{type(idx).__name__}')")
|
344
|
+
|
345
|
+
dataset = copy.copy(self)
|
346
|
+
dataset._indices = indices
|
347
|
+
return dataset
|
348
|
+
|
349
|
+
def shuffle(
|
350
|
+
self,
|
351
|
+
return_perm: bool = False,
|
352
|
+
) -> Union['Dataset', Tuple['Dataset', Tensor]]:
|
353
|
+
r"""Randomly shuffles the examples in the dataset.
|
354
|
+
|
355
|
+
Args:
|
356
|
+
return_perm (bool, optional): If set to :obj:`True`, will return
|
357
|
+
the random permutation used to shuffle the dataset in addition.
|
358
|
+
(default: :obj:`False`)
|
359
|
+
"""
|
360
|
+
perm = torch.randperm(len(self))
|
361
|
+
dataset = self.index_select(perm)
|
362
|
+
return (dataset, perm) if return_perm is True else dataset
|
File without changes
|
@@ -0,0 +1,71 @@
|
|
1
|
+
import typing as t
|
2
|
+
|
3
|
+
import pandas as pd
|
4
|
+
import numpy as np
|
5
|
+
from torch_geometric.loader import DataLoader
|
6
|
+
|
7
|
+
from hdl.data.dataset.graph.chiral import MolDataset
|
8
|
+
from hdl.data.dataset.samplers.chiral import StereoSampler
|
9
|
+
from hdl.data.dataset.loaders.spliter import split_data
|
10
|
+
|
11
|
+
|
12
|
+
def get_chiralgraph_loader(
    data_path: str = None,
    smiles_list: t.List = None,
    label_list: t.List = None,
    batch_size: int = 1,
    shuffle: bool = False,
    smiles_col: str = 'SMILES',
    label_col: str = 'label',
    num_workers: int = 10,
    shuffle_pairs: bool = False,
    chiral_features: bool = True,
    global_chiral_features: bool = True
):
    """Build a chiral molecular-graph dataset and its PyG ``DataLoader``.

    Args:
        data_path: optional CSV path; when given, SMILES and labels are read
            from columns ``smiles_col`` and ``label_col``.
        smiles_list / label_list: in-memory inputs used when ``data_path``
            is None. (default: empty lists)
        batch_size / shuffle / num_workers: forwarded to the DataLoader.
        shuffle_pairs: when True, sample with ``StereoSampler`` so adjacent
            (paired) examples stay together.
        chiral_features / global_chiral_features: forwarded to ``MolDataset``.

    Returns:
        ``(loader, dataset)``.
    """
    # Fix: mutable default arguments ([]) replaced by the None sentinel;
    # behaviour for callers relying on the default is unchanged.
    if smiles_list is None:
        smiles_list = []
    if label_list is None:
        label_list = []

    if data_path is not None:
        data_df = pd.read_csv(data_path)
        smiles = data_df[smiles_col].tolist()
        labels = data_df[label_col].to_numpy()
    else:
        smiles = smiles_list
        labels = np.array(label_list)

    dataset = MolDataset(
        smiles=smiles,
        labels=labels,
        chiral_features=chiral_features,
        global_chiral_features=global_chiral_features
    )
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        sampler=StereoSampler(dataset) if shuffle_pairs else None)
    # Fix: removed ~45 lines of dead code that followed this return in the
    # original (a split-loader section after ``return loader, dataset`` was
    # unreachable).
    return loader, dataset
|
File without changes
|
@@ -0,0 +1,56 @@
|
|
1
|
+
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
|
2
|
+
collate samples fetched from dataset into Tensor(s).
|
3
|
+
|
4
|
+
These **needs** to be in global scope since Py2 doesn't support serializing
|
5
|
+
static methods.
|
6
|
+
"""
|
7
|
+
import typing as t
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
import pandas as pd
|
11
|
+
import torch
|
12
|
+
|
13
|
+
|
14
|
+
# Types treated as "integer-like" label values.
# NOTE(review): pd.Int16Dtype/Int32Dtype/Int64Dtype are pandas *dtype
# classes*, not scalar value types — an isinstance() check against them
# would never match plain label values; currently this tuple is only
# referenced from commented-out code in ``fp_collate``. Confirm intent.
int_types = (
    int,
    np.int32,
    np.int64,
    pd.Int16Dtype,
    pd.Int32Dtype,
    pd.Int64Dtype,
    # torch.int32
)
|
23
|
+
|
24
|
+
|
25
|
+
def fp_collate(batch):
    """Collate fingerprint samples into batched float tensors.

    Each sample is a tuple whose first element is a tuple of fingerprint
    tensors.  When present, the last element holds raw target labels, and —
    for 3-element samples — the middle element holds target tensors.

    Returns:
        ``fps`` alone for 1-element samples;
        ``(fps, target_tensors, targets_list)`` for 3-element samples;
        ``(fps, targets, targets_list)`` otherwise.
    """
    columns = list(zip(*batch))

    # Stack each fingerprint slot across the batch.
    fps = [torch.vstack(group).float() for group in zip(*columns[0])]
    if len(columns) == 1:
        return fps

    # Group the raw labels per target; scalar labels become one float
    # tensor per target, iterable labels are kept as plain lists.
    targets = list(zip(*columns[-1]))
    targets_list = [
        list(labels) if isinstance(labels[0], t.Iterable)
        else torch.Tensor(labels)
        for labels in targets
    ]
    # if isinstance(target_labels[0], int_types):
    #     target_labels = torch.LongTensor(target_labels)

    if len(columns) == 3:
        target_tensors = [
            torch.vstack(group).float() for group in zip(*columns[1])
        ]
        return fps, target_tensors, targets_list
    return fps, targets, targets_list
|
@@ -0,0 +1,40 @@
|
|
1
|
+
from rxnfp.tokenization import (
|
2
|
+
SmilesTokenizer,
|
3
|
+
convert_reaction_to_valid_features_batch,
|
4
|
+
)
|
5
|
+
import torch
|
6
|
+
import numpy as np
|
7
|
+
import pkg_resources
|
8
|
+
|
9
|
+
|
10
|
+
__all__ = [
|
11
|
+
'collate_rxn',
|
12
|
+
]
|
13
|
+
|
14
|
+
|
15
|
+
def collate_rxn(
    rxn_list,
    labels,
    vocab_path: str = None,
    max_len: int = 512
):
    """Collate a batch of reaction SMILES into BERT-style input tensors.

    Args:
        rxn_list: batch of reaction SMILES strings.
        labels: integer class labels for the batch.
        vocab_path: tokenizer vocabulary file; defaults to the vocab shipped
            with the ``rxnfp`` package (bert_ft model).
        max_len: maximum tokenized sequence length.

    Returns:
        ``(X, y)`` where ``X = [input_ids, input_mask, segment_ids]`` as
        int64 tensors and ``y`` is a LongTensor of labels.

    NOTE(review): a new ``SmilesTokenizer`` is constructed on every call;
    if used as a DataLoader ``collate_fn`` consider caching it per
    ``vocab_path``.
    """
    if vocab_path is None:
        vocab_path = pkg_resources.resource_filename(
            "rxnfp",
            "models/transformers/bert_ft/vocab.txt"
        )
    tokenizer = SmilesTokenizer(
        vocab_path, max_len=max_len
    )

    feats = convert_reaction_to_valid_features_batch(
        rxn_list,
        tokenizer
    )
    X = [
        torch.tensor(feats.input_ids.astype(np.int64)),
        torch.tensor(feats.input_mask.astype(np.int64)),
        torch.tensor(feats.segment_ids.astype(np.int64))
    ]
    y = torch.LongTensor(labels)
    return X, y
|
@@ -0,0 +1,23 @@
|
|
1
|
+
import typing as t
|
2
|
+
|
3
|
+
import torch.utils.data as tud
|
4
|
+
|
5
|
+
from hdl.data.dataset.loaders.collate_funcs.fp import fp_collate
|
6
|
+
|
7
|
+
|
8
|
+
class Loader(tud.DataLoader):
    """Thin ``DataLoader`` wrapper with fingerprint-friendly defaults:
    ``fp_collate`` collation, shuffling on, batch size 128, 12 workers."""

    def __init__(
        self,
        dataset,
        batch_size: int = 128,
        shuffle: bool = True,
        num_workers: int = 12,
        collate_fn: t.Callable = fp_collate
    ):
        loader_kwargs = dict(
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=collate_fn,
        )
        super().__init__(dataset, **loader_kwargs)
|
@@ -0,0 +1,86 @@
|
|
1
|
+
from typing import DefaultDict, Tuple
|
2
|
+
from random import Random
|
3
|
+
from collections import defaultdict
|
4
|
+
from rdkit import Chem
|
5
|
+
from rdkit.Chem.Scaffolds import MurckoScaffold
|
6
|
+
from hdl.data.dataset.graph.chiral import MolDataset
|
7
|
+
|
8
|
+
|
9
|
+
def split_data(
    smis: Tuple[str],
    labels: Tuple,
    split_type: str = "random",
    sizes: Tuple[float, float, float] = (0.8, 0.2, 0.0),
    seed: int = 999,
    num_folds: int = 1,
    balanced: bool = True,
    args=None,
) -> Tuple[Tuple[str], Tuple[str], Tuple[str]]:
    """Split molecules into train/val/test subsets.

    Args:
        smis: SMILES strings.
        labels: labels aligned with ``smis``.
        split_type: "random" (shuffled index split; each subset is a
            ``[smiles, labels]`` pair) or "scaffold_balanced" (Murcko
            scaffold split; each subset is a SMILES list only —
            NOTE(review): the two branches return differently shaped
            results and the scaffold branch drops labels; confirm this
            asymmetry is intended).
        sizes: train/val/test fractions.
        seed: RNG seed for reproducibility.
        num_folds, args: accepted but currently unused.

    Returns:
        ``(train, val, test)``.

    Raises:
        ValueError: for an unknown ``split_type`` (the original fell through
        to a NameError).
    """
    random = Random(seed)

    if split_type == "random":
        indices = list(range(len(smis)))
        random.shuffle(indices)

        train_size = int(sizes[0] * len(smis))
        train_val_size = int((sizes[0] + sizes[1]) * len(smis))
        train = [
            [smis[i] for i in indices[:train_size]],
            [labels[i] for i in indices[:train_size]],
        ]
        val = [
            [smis[i] for i in indices[train_size:train_val_size]],
            [labels[i] for i in indices[train_size:train_val_size]],
        ]
        test = [
            [smis[i] for i in indices[train_val_size:]],
            [labels[i] for i in indices[train_val_size:]],
        ]
    elif split_type == "scaffold_balanced":
        # Fix: the original computed subset sizes from the undefined name
        # ``data`` (NameError); the dataset size is len(smis).
        train_size, val_size, test_size = (
            sizes[0] * len(smis),
            sizes[1] * len(smis),
            sizes[2] * len(smis),
        )
        train, val, test = [], [], []
        train_scaffold_count, val_scaffold_count, test_scaffold_count = 0, 0, 0
        # Group molecule indices by their Murcko scaffold.
        scaffold_to_indices = defaultdict(set)
        rdmols = [Chem.MolFromSmiles(s) for s in smis]
        for i, rdmol in enumerate(rdmols):
            scaffold = MurckoScaffold.MurckoScaffoldSmiles(
                mol=rdmol, includeChirality=False
            )
            scaffold_to_indices[scaffold].add(i)
        if balanced:
            # Scaffolds larger than half the val/test budget go first (into
            # train); the remaining small groups are shuffled.
            index_sets = list(scaffold_to_indices.values())
            big_index_sets = []
            small_index_sets = []
            for index_set in index_sets:
                if len(index_set) > val_size / 2 or len(index_set) > test_size / 2:
                    big_index_sets.append(index_set)
                else:
                    small_index_sets.append(index_set)
            random.seed(seed)
            random.shuffle(big_index_sets)
            random.shuffle(small_index_sets)
            index_sets = big_index_sets + small_index_sets
        else:
            # Deterministic: largest scaffold groups first.
            index_sets = sorted(
                list(scaffold_to_indices.values()),
                key=lambda index_set: len(index_set),
                reverse=True,
            )
        # Greedily fill train, then val; everything else goes to test.
        for index_set in index_sets:
            if len(train) + len(index_set) <= train_size:
                train += index_set
                train_scaffold_count += 1
            elif len(val) + len(index_set) <= val_size:
                val += index_set
                val_scaffold_count += 1
            else:
                test += index_set
                test_scaffold_count += 1
        train = [smis[i] for i in train]
        val = [smis[i] for i in val]
        test = [smis[i] for i in test]
    else:
        raise ValueError(f"Unsupported split_type: {split_type!r}")
    return train, val, test
|
File without changes
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from itertools import chain
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
from torch.utils.data.sampler import Sampler
|
5
|
+
|
6
|
+
|
7
|
+
class StereoSampler(Sampler):
    """Sampler that shuffles the dataset in adjacent index pairs.

    Indices are grouped as (0, 1), (2, 3), ...; the pairs themselves are
    shuffled, but the two indices of each pair stay together and in order,
    so consecutively stored paired examples (e.g. enantiomers) remain
    adjacent in the sampling order.
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        pair_starts = range(0, len(self.data_source), 2)
        groups = [[start, start + 1] for start in pair_starts]
        np.random.shuffle(groups)
        flattened = list(chain(*groups))
        return iter(flattened)

    def __len__(self):
        return len(self.data_source)
|
File without changes
|