hjxdl-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. hdl/__init__.py +0 -0
  2. hdl/_version.py +16 -0
  3. hdl/args/__init__.py +0 -0
  4. hdl/args/loss_args.py +5 -0
  5. hdl/controllers/__init__.py +0 -0
  6. hdl/controllers/al/__init__.py +0 -0
  7. hdl/controllers/al/al.py +0 -0
  8. hdl/controllers/al/dispatcher.py +0 -0
  9. hdl/controllers/al/feedback.py +0 -0
  10. hdl/controllers/explain/__init__.py +0 -0
  11. hdl/controllers/explain/shapley.py +293 -0
  12. hdl/controllers/explain/subgraphx.py +865 -0
  13. hdl/controllers/train/__init__.py +0 -0
  14. hdl/controllers/train/rxn_train.py +219 -0
  15. hdl/controllers/train/train.py +50 -0
  16. hdl/controllers/train/train_ginet.py +316 -0
  17. hdl/controllers/train/trainer_base.py +155 -0
  18. hdl/controllers/train/trainer_iterative.py +389 -0
  19. hdl/data/__init__.py +0 -0
  20. hdl/data/dataset/__init__.py +0 -0
  21. hdl/data/dataset/base_dataset.py +98 -0
  22. hdl/data/dataset/fp/__init__.py +0 -0
  23. hdl/data/dataset/fp/fp_dataset.py +122 -0
  24. hdl/data/dataset/graph/__init__.py +0 -0
  25. hdl/data/dataset/graph/chiral.py +62 -0
  26. hdl/data/dataset/graph/gin.py +255 -0
  27. hdl/data/dataset/graph/molnet.py +362 -0
  28. hdl/data/dataset/loaders/__init__.py +0 -0
  29. hdl/data/dataset/loaders/chiral_graph.py +71 -0
  30. hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
  31. hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
  32. hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
  33. hdl/data/dataset/loaders/general.py +23 -0
  34. hdl/data/dataset/loaders/spliter.py +86 -0
  35. hdl/data/dataset/samplers/__init__.py +0 -0
  36. hdl/data/dataset/samplers/chiral.py +19 -0
  37. hdl/data/dataset/seq/__init__.py +0 -0
  38. hdl/data/dataset/seq/rxn_dataset.py +61 -0
  39. hdl/data/dataset/utils.py +31 -0
  40. hdl/data/to_mols.py +0 -0
  41. hdl/features/__init__.py +0 -0
  42. hdl/features/fp/__init__.py +0 -0
  43. hdl/features/fp/features_generators.py +235 -0
  44. hdl/features/graph/__init__.py +0 -0
  45. hdl/features/graph/featurization.py +297 -0
  46. hdl/features/utils/__init__.py +0 -0
  47. hdl/features/utils/utils.py +111 -0
  48. hdl/layers/__init__.py +0 -0
  49. hdl/layers/general/__init__.py +0 -0
  50. hdl/layers/general/gp.py +14 -0
  51. hdl/layers/general/linear.py +641 -0
  52. hdl/layers/graph/__init__.py +0 -0
  53. hdl/layers/graph/chiral_graph.py +230 -0
  54. hdl/layers/graph/gcn.py +16 -0
  55. hdl/layers/graph/gin.py +45 -0
  56. hdl/layers/graph/tetra.py +158 -0
  57. hdl/layers/graph/transformer.py +188 -0
  58. hdl/layers/sequential/__init__.py +0 -0
  59. hdl/metric_loss/__init__.py +0 -0
  60. hdl/metric_loss/loss.py +79 -0
  61. hdl/metric_loss/metric.py +178 -0
  62. hdl/metric_loss/multi_label.py +42 -0
  63. hdl/metric_loss/nt_xent.py +65 -0
  64. hdl/models/__init__.py +0 -0
  65. hdl/models/chiral_gnn.py +176 -0
  66. hdl/models/fast_transformer.py +234 -0
  67. hdl/models/ginet.py +189 -0
  68. hdl/models/linear.py +137 -0
  69. hdl/models/model_dict.py +18 -0
  70. hdl/models/norm_flows.py +33 -0
  71. hdl/models/optim_dict.py +16 -0
  72. hdl/models/rxn.py +63 -0
  73. hdl/models/utils.py +83 -0
  74. hdl/ops/__init__.py +0 -0
  75. hdl/ops/utils.py +42 -0
  76. hdl/optims/__init__.py +0 -0
  77. hdl/optims/nadam.py +86 -0
  78. hdl/utils/__init__.py +0 -0
  79. hdl/utils/chemical_tools/__init__.py +2 -0
  80. hdl/utils/chemical_tools/query_info.py +149 -0
  81. hdl/utils/chemical_tools/sdf.py +20 -0
  82. hdl/utils/database_tools/__init__.py +0 -0
  83. hdl/utils/database_tools/connect.py +28 -0
  84. hdl/utils/general/__init__.py +0 -0
  85. hdl/utils/general/glob.py +21 -0
  86. hdl/utils/schedulers/__init__.py +0 -0
  87. hdl/utils/schedulers/norm_lr.py +108 -0
  88. hjxdl-0.0.1.dist-info/METADATA +19 -0
  89. hjxdl-0.0.1.dist-info/RECORD +91 -0
  90. hjxdl-0.0.1.dist-info/WHEEL +5 -0
  91. hjxdl-0.0.1.dist-info/top_level.txt +1 -0
hdl/data/dataset/graph/molnet.py
@@ -0,0 +1,362 @@
+ # import os
+ import copy
+ import os.path as osp
+ import re
+ from itertools import product, repeat
+ from typing import List, Tuple, Dict, Optional, Union
+ from collections.abc import Sequence
+
+ import torch
+ from torch import Tensor
+ from torch_geometric.data import (
+     InMemoryDataset, Data
+ )
+ from torch_geometric.data.dataset import Dataset, IndexType
+ import numpy as np
+
+ try:
+     from rdkit import Chem
+ except ImportError:
+     Chem = None
+
+ from jupyfuncs.pbar import tqdm
+
+
+ x_map = {
+     'atomic_num':
+     list(range(0, 119)),
+     'chirality': [
+         'CHI_UNSPECIFIED',
+         'CHI_TETRAHEDRAL_CW',
+         'CHI_TETRAHEDRAL_CCW',
+         'CHI_OTHER',
+     ],
+     'degree':
+     list(range(0, 11)),
+     'formal_charge':
+     list(range(-5, 7)),
+     'num_hs':
+     list(range(0, 9)),
+     'num_radical_electrons':
+     list(range(0, 5)),
+     'hybridization': [
+         'UNSPECIFIED',
+         'S',
+         'SP',
+         'SP2',
+         'SP3',
+         'SP3D',
+         'SP3D2',
+         'OTHER',
+     ],
+     'is_aromatic': [False, True],
+     'is_in_ring': [False, True],
+ }
+
+ e_map = {
+     'bond_type': [
+         'misc',
+         'SINGLE',
+         'DOUBLE',
+         'TRIPLE',
+         'AROMATIC',
+     ],
+     'stereo': [
+         'STEREONONE',
+         'STEREOZ',
+         'STEREOE',
+         'STEREOCIS',
+         'STEREOTRANS',
+         'STEREOANY',
+     ],
+     'is_conjugated': [False, True],
+ }
+
+
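The two lookup tables above encode RDKit atom and bond properties as integer indices (the OGB-style featurization that `process` below relies on). A minimal sketch of the lookup, assuming only that rdkit is installed:

    from rdkit import Chem
    from hdl.data.dataset.graph.molnet import x_map

    atom = Chem.MolFromSmiles('CCO').GetAtomWithIdx(0)  # a carbon in ethanol
    # every property becomes its position in the corresponding list
    print(x_map['atomic_num'].index(atom.GetAtomicNum()))               # 6
    print(x_map['hybridization'].index(str(atom.GetHybridization())))  # 4 ('SP3')
    print(x_map['is_aromatic'].index(atom.GetIsAromatic()))            # 0 (False)
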
+ class MoleculeNet(torch.utils.data.Dataset):
+     r"""The `MoleculeNet <http://moleculenet.ai/datasets-1>`_ benchmark
+     collection from the `"MoleculeNet: A Benchmark for Molecular Machine
+     Learning" <https://arxiv.org/abs/1703.00564>`_ paper, containing datasets
+     from physical chemistry, biophysics and physiology.
+     All datasets come with the additional node and edge features introduced by
+     the `Open Graph Benchmark <https://ogb.stanford.edu/docs/graphprop/>`_.
+
+     Args:
+         root (string): Root directory where the dataset should be saved.
+         name (string): The name of the dataset (:obj:`"ESOL"`,
+             :obj:`"FreeSolv"`, :obj:`"Lipo"`, :obj:`"PCBA"`, :obj:`"MUV"`,
+             :obj:`"HIV"`, :obj:`"BACE"`, :obj:`"BBBP"`, :obj:`"Tox21"`,
+             :obj:`"ToxCast"`, :obj:`"SIDER"`, :obj:`"ClinTox"`).
+         transform (callable, optional): A function/transform that takes in a
+             :obj:`torch_geometric.data.Data` object and returns a transformed
+             version. The data object will be transformed before every access.
+             (default: :obj:`None`)
+         pre_transform (callable, optional): A function/transform that takes in
+             a :obj:`torch_geometric.data.Data` object and returns a
+             transformed version. The data object will be transformed before
+             being saved to disk. (default: :obj:`None`)
+         pre_filter (callable, optional): A function that takes in a
+             :obj:`torch_geometric.data.Data` object and returns a boolean
+             value, indicating whether the data object should be included in
+             the final dataset. (default: :obj:`None`)
+     """
+
+     def __init__(self, root, file_type='smi_in_csv',
+                  transform=None, pre_transform=None,
+                  pre_filter=None):
+
+         self.file_type = file_type
+
+         if Chem is None:
+             raise ImportError('`MoleculeNet` requires `rdkit`.')
+
+         self.transform = transform
+         self.pre_transform = pre_transform
+         self.pre_filter = pre_filter
+         # Index view and per-item cache used by indices(), get() and
+         # index_select() below.
+         self._indices = None
+         self._data_list = None
+
+         # self.name = name.lower()
+         # assert self.name in self.names.keys()
+         # super(MoleculeNet, self).__init__(root, transform, pre_transform,
+         #                                   pre_filter)
+         self.root_dir = root
+         self.processed_file = osp.join(
+             self.root_dir, 'processed.pt'
+         )
+         self.process()
+         self.data, self.slices = torch.load(self.processed_file)
+
+     def process(self):
+         with open(self.raw_paths[0], 'r') as f:
+             dataset = f.read().split('\n')[1:-1]
+             dataset = [x for x in dataset if len(x) > 0]  # Filter empty lines.
+
+         data_list = []
+         for line in tqdm(dataset):
+             line = re.sub(r'\".*\"', '', line)  # Strip quoted ".*" strings.
+             line = line.split(',')
+
+             smiles = line[self.names[self.name][3]]
+             ys = line[self.names[self.name][4]]
+             ys = ys if isinstance(ys, list) else [ys]
+
+             ys = [float(y) if len(y) > 0 else float('NaN') for y in ys]
+             y = torch.tensor(ys, dtype=torch.float).view(1, -1)
+
+             mol = Chem.MolFromSmiles(smiles)
+             if mol is None:
+                 continue
+
+             # Encode every atom as a 9-dimensional vector of x_map indices.
+             xs = []
+             for atom in mol.GetAtoms():
+                 x = []
+                 x.append(x_map['atomic_num'].index(atom.GetAtomicNum()))
+                 x.append(x_map['chirality'].index(str(atom.GetChiralTag())))
+                 x.append(x_map['degree'].index(atom.GetTotalDegree()))
+                 x.append(x_map['formal_charge'].index(atom.GetFormalCharge()))
+                 x.append(x_map['num_hs'].index(atom.GetTotalNumHs()))
+                 x.append(x_map['num_radical_electrons'].index(
+                     atom.GetNumRadicalElectrons()))
+                 x.append(x_map['hybridization'].index(
+                     str(atom.GetHybridization())))
+                 x.append(x_map['is_aromatic'].index(atom.GetIsAromatic()))
+                 x.append(x_map['is_in_ring'].index(atom.IsInRing()))
+                 xs.append(x)
+
+             x = torch.tensor(xs, dtype=torch.long).view(-1, 9)
+
+             # Encode every bond twice (i->j and j->i) with 3 e_map indices.
+             edge_indices, edge_attrs = [], []
+             for bond in mol.GetBonds():
+                 i = bond.GetBeginAtomIdx()
+                 j = bond.GetEndAtomIdx()
+
+                 e = []
+                 e.append(e_map['bond_type'].index(str(bond.GetBondType())))
+                 e.append(e_map['stereo'].index(str(bond.GetStereo())))
+                 e.append(e_map['is_conjugated'].index(bond.GetIsConjugated()))
+
+                 edge_indices += [[i, j], [j, i]]
+                 edge_attrs += [e, e]
+
+             edge_index = torch.tensor(edge_indices)
+             edge_index = edge_index.t().to(torch.long).view(2, -1)
+             edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3)
+
+             # Sort indices.
+             if edge_index.numel() > 0:
+                 perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
+                 edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]
+
+             data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y,
+                         smiles=smiles)
+
+             if self.pre_filter is not None and not self.pre_filter(data):
+                 continue
+
+             if self.pre_transform is not None:
+                 data = self.pre_transform(data)
+
+             data_list.append(data)
+
+         torch.save(
+             self.collate(data_list),
+             self.processed_file
+         )
+
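The `# Sort indices.` step above orders the doubled edge list by flattened (row, col) position; the sort key `row * num_nodes + col` is just a row-major linear index. A standalone check of the trick:

    import torch

    edge_index = torch.tensor([[1, 0, 2, 0],
                               [0, 1, 0, 2]])
    perm = (edge_index[0] * 3 + edge_index[1]).argsort()  # 3 = num_nodes
    print(edge_index[:, perm])
    # tensor([[0, 0, 1, 2],
    #         [1, 2, 0, 0]])
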
+     def collate(self, data_list: List[Data]) -> Tuple[Data, Dict[str, Tensor]]:
+         r"""Collates a python list of data objects to the internal storage
+         format of :class:`torch_geometric.data.InMemoryDataset`."""
+         keys = data_list[0].keys
+         data = data_list[0].__class__()
+
+         for key in keys:
+             data[key] = []
+         slices = {key: [0] for key in keys}
+
+         for item, key in product(data_list, keys):
+             data[key].append(item[key])
+             if isinstance(item[key], Tensor) and item[key].dim() > 0:
+                 cat_dim = item.__cat_dim__(key, item[key])
+                 cat_dim = 0 if cat_dim is None else cat_dim
+                 s = slices[key][-1] + item[key].size(cat_dim)
+             else:
+                 s = slices[key][-1] + 1
+             slices[key].append(s)
+
+         # Concatenate the per-key lists into single tensors, mirroring
+         # torch_geometric.data.InMemoryDataset.collate.
+         for key in keys:
+             item = data_list[0][key]
+             if isinstance(item, Tensor) and item.dim() > 0:
+                 cat_dim = data_list[0].__cat_dim__(key, item)
+                 cat_dim = 0 if cat_dim is None else cat_dim
+                 data[key] = torch.cat(data[key], dim=cat_dim)
+             elif isinstance(item, (int, float)):
+                 data[key] = torch.tensor(data[key])
+             slices[key] = torch.tensor(slices[key], dtype=torch.long)
+
+         return data, slices
+
+     def __repr__(self):
+         return 'MoleculeNet ({})'.format(len(self))
+
+     def copy(self, idx: Optional[IndexType] = None):
+         if idx is None:
+             data_list = [self.get(i) for i in range(len(self))]
+         else:
+             data_list = [self.get(i) for i in self.index_select(idx).indices()]
+
+         dataset = copy.copy(self)
+         dataset._indices = None
+         dataset._data_list = data_list
+         dataset.data, dataset.slices = self.collate(data_list)
+         return dataset
+
+     @property
+     def num_classes(self) -> int:
+         r"""The number of classes in the dataset."""
+         y = self.data.y
+         if y is None:
+             return 0
+         elif y.numel() == y.size(0) and not torch.is_floating_point(y):
+             return int(self.data.y.max()) + 1
+         elif y.numel() == y.size(0) and torch.is_floating_point(y):
+             return torch.unique(y).numel()
+         else:
+             return self.data.y.size(-1)
+
+     def len(self) -> int:
+         for item in self.slices.values():
+             return len(item) - 1
+         return 0
+
+     def get(self, idx: int) -> Data:
+         if hasattr(self, '_data_list'):
+             if self._data_list is None:
+                 self._data_list = self.len() * [None]
+             else:
+                 data = self._data_list[idx]
+                 if data is not None:
+                     return copy.copy(data)
+
+         data = self.data.__class__()
+         if hasattr(self.data, '__num_nodes__'):
+             data.num_nodes = self.data.__num_nodes__[idx]
+
+         for key in self.data.keys:
+             item, slices = self.data[key], self.slices[key]
+             start, end = slices[idx].item(), slices[idx + 1].item()
+             if torch.is_tensor(item):
+                 s = list(repeat(slice(None), item.dim()))
+                 cat_dim = self.data.__cat_dim__(key, item)
+                 if cat_dim is None:
+                     cat_dim = 0
+                 s[cat_dim] = slice(start, end)
+             elif start + 1 == end:
+                 s = slices[start]
+             else:
+                 s = slice(start, end)
+             data[key] = item[s]
+
+         if hasattr(self, '_data_list'):
+             self._data_list[idx] = copy.copy(data)
+
+         return data
+
+     def indices(self) -> Sequence:
+         # Mirrors torch_geometric.data.Dataset.indices(); __len__,
+         # __getitem__ and index_select below all depend on it.
+         return range(self.len()) if self._indices is None else self._indices
+
+     def __len__(self) -> int:
+         r"""The number of examples in the dataset."""
+         return len(self.indices())
+
+     def __getitem__(
+         self,
+         idx: Union[int, np.integer, IndexType],
+     ) -> Union['Dataset', Data]:
+         r"""In case :obj:`idx` is of type integer, will return the data object
+         at index :obj:`idx` (and transforms it in case :obj:`transform` is
+         present).
+         In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a
+         tuple, a PyTorch :obj:`LongTensor` or :obj:`BoolTensor`, or a numpy
+         :obj:`np.array`, will return a subset of the dataset at the specified
+         indices."""
+         if (isinstance(idx, (int, np.integer))
+                 or (isinstance(idx, Tensor) and idx.dim() == 0)
+                 or (isinstance(idx, np.ndarray) and np.isscalar(idx))):
+
+             data = self.get(self.indices()[idx])
+             data = data if self.transform is None else self.transform(data)
+             return data
+
+         else:
+             return self.index_select(idx)
+
+     def index_select(self, idx: IndexType) -> 'Dataset':
+         indices = self.indices()
+
+         if isinstance(idx, slice):
+             indices = indices[idx]
+
+         elif isinstance(idx, Tensor) and idx.dtype == torch.long:
+             return self.index_select(idx.flatten().tolist())
+
+         elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
+             idx = idx.flatten().nonzero(as_tuple=False)
+             return self.index_select(idx.flatten().tolist())
+
+         elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
+             return self.index_select(idx.flatten().tolist())
+
+         elif isinstance(idx, np.ndarray) and idx.dtype == bool:
+             idx = idx.flatten().nonzero()[0]
+             return self.index_select(idx.flatten().tolist())
+
+         elif isinstance(idx, Sequence) and not isinstance(idx, str):
+             indices = [indices[i] for i in idx]
+
+         else:
+             raise IndexError(
+                 f"Only integers, slices (':'), list, tuples, torch.tensor and "
+                 f"np.ndarray of dtype long or bool are valid indices (got "
+                 f"'{type(idx).__name__}')")
+
+         dataset = copy.copy(self)
+         dataset._indices = indices
+         return dataset
+
+     def shuffle(
+         self,
+         return_perm: bool = False,
+     ) -> Union['Dataset', Tuple['Dataset', Tensor]]:
+         r"""Randomly shuffles the examples in the dataset.
+
+         Args:
+             return_perm (bool, optional): If set to :obj:`True`, will also
+                 return the random permutation used to shuffle the dataset.
+                 (default: :obj:`False`)
+         """
+         perm = torch.randperm(len(self))
+         dataset = self.index_select(perm)
+         return (dataset, perm) if return_perm is True else dataset
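Together, `indices`, `index_select`, and `__getitem__` reproduce the indexing surface of `torch_geometric.data.Dataset`. A hedged sketch, assuming `ds` is an already-processed `MoleculeNet` instance:

    import torch

    first = ds[0]                      # single Data object, transform applied
    head = ds[:100]                    # slice -> lightweight view
    picked = ds[torch.tensor([0, 2])]  # long tensor -> index_select
    shuffled, perm = ds.shuffle(return_perm=True)
    print(len(head), len(picked), perm[:5])
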
hdl/data/dataset/loaders/chiral_graph.py
@@ -0,0 +1,71 @@
+ import typing as t
+
+ import pandas as pd
+ import numpy as np
+ from torch_geometric.loader import DataLoader
+
+ from hdl.data.dataset.graph.chiral import MolDataset
+ from hdl.data.dataset.samplers.chiral import StereoSampler
+ from hdl.data.dataset.loaders.spliter import split_data
+
+
+ def get_chiralgraph_loader(
+     data_path: str = None,
+     smiles_list: t.List = None,
+     label_list: t.List = None,
+     batch_size: int = 1,
+     shuffle: bool = False,
+     smiles_col: str = 'SMILES',
+     label_col: str = 'label',
+     num_workers: int = 10,
+     shuffle_pairs: bool = False,
+     chiral_features: bool = True,
+     global_chiral_features: bool = True,
+     split: bool = False,
+ ):
+     if data_path is not None:
+         data_df = pd.read_csv(data_path)
+
+         # smiles = data_df.iloc[:, 0].values
+         # labels = data_df.iloc[:, 1].values.astype(np.float32)
+         smiles = data_df[smiles_col].tolist()
+         labels = data_df[label_col].to_numpy()
+     else:
+         smiles = smiles_list or []
+         labels = np.array(label_list or [])
+
+     if not split:
+         dataset = MolDataset(
+             smiles=smiles,
+             labels=labels,
+             chiral_features=chiral_features,
+             global_chiral_features=global_chiral_features
+         )
+         loader = DataLoader(
+             dataset=dataset,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             num_workers=num_workers,
+             pin_memory=True,
+             sampler=StereoSampler(dataset) if shuffle_pairs else None)
+         return loader, dataset
+
+     # split=True: build one loader per random split.
+     split_loader_list = []
+     split_data_list = split_data(smiles, labels, split_type="random")
+     for split_smiles, split_labels in split_data_list:
+         dataset = MolDataset(
+             smiles=split_smiles,
+             labels=split_labels,
+             chiral_features=chiral_features,
+             global_chiral_features=global_chiral_features,
+         )
+
+         loader = DataLoader(dataset=dataset,
+                             batch_size=batch_size,
+                             shuffle=shuffle,
+                             num_workers=num_workers,
+                             pin_memory=True,
+                             sampler=StereoSampler(dataset) if shuffle_pairs else None)
+         split_loader_list.append(loader)
+
+     return split_loader_list, dataset
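A short usage sketch for the loader, assuming a CSV with the default `SMILES` and `label` columns (the path is illustrative):

    from hdl.data.dataset.loaders.chiral_graph import get_chiralgraph_loader

    loader, dataset = get_chiralgraph_loader(
        data_path='data/chiral_train.csv',  # hypothetical file
        batch_size=32,
        shuffle=True,
    )
    batch = next(iter(loader))  # a torch_geometric Batch of molecular graphs
    print(batch.num_graphs)
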
hdl/data/dataset/loaders/collate_funcs/fp.py
@@ -0,0 +1,56 @@
+ r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers
+ to collate samples fetched from the dataset into Tensor(s).
+
+ These need to be in global scope since Py2 doesn't support serializing
+ static methods.
+ """
+ import typing as t
+
+ import numpy as np
+ import pandas as pd
+ import torch
+
+
+ int_types = (
+     int,
+     np.int32,
+     np.int64,
+     pd.Int16Dtype,
+     pd.Int32Dtype,
+     pd.Int64Dtype,
+     # torch.int32
+ )
+
+
+ def fp_collate(batch):
+     # batch is a list of per-sample tuples; transpose to per-field tuples.
+     transposed = list(zip(*batch))
+
+     # fps: stack each fingerprint field into a (batch, dim) float tensor.
+     fps = list(zip(*transposed[0]))
+     fps = [torch.vstack(fp).float() for fp in fps]
+     if len(transposed) == 1:
+         return fps
+
+     # target_list: scalar labels become tensors, iterable labels stay lists.
+     targets = list(zip(*transposed[-1]))
+     targets_list = []
+     for target_labels in targets:
+         if not isinstance(target_labels[0], t.Iterable):
+             target_labels = torch.Tensor(target_labels)
+         else:
+             target_labels = list(target_labels)
+         # if isinstance(target_labels[0], int_types):
+         #     target_labels = torch.LongTensor(target_labels)
+         targets_list.append(target_labels)
+
+     # target_tensors: present only for 3-field samples (fps, tensors, labels).
+     if len(transposed) == 3:
+         target_tensors = list(zip(*transposed[1]))
+         target_tensors = [
+             torch.vstack(target_tensor).float()
+             for target_tensor in target_tensors
+         ]
+         return fps, target_tensors, targets_list
+     else:
+         return fps, targets, targets_list
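A toy invocation of `fp_collate`, with two samples that each carry two fingerprint tensors and a scalar label (shapes and labels made up for illustration):

    import torch
    from hdl.data.dataset.loaders.collate_funcs.fp import fp_collate

    batch = [
        ((torch.zeros(8), torch.ones(4)), (0,)),
        ((torch.ones(8), torch.zeros(4)), (1,)),
    ]
    fps, targets, targets_list = fp_collate(batch)
    print(fps[0].shape, fps[1].shape)  # torch.Size([2, 8]) torch.Size([2, 4])
    print(targets_list[0])             # tensor([0., 1.])
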
hdl/data/dataset/loaders/collate_funcs/rxn.py
@@ -0,0 +1,40 @@
+ from rxnfp.tokenization import (
+     SmilesTokenizer,
+     convert_reaction_to_valid_features_batch,
+ )
+ import torch
+ import numpy as np
+ import pkg_resources
+
+
+ __all__ = [
+     'collate_rxn',
+ ]
+
+
+ def collate_rxn(
+     rxn_list,
+     labels,
+     vocab_path: str = None,
+     max_len: int = 512
+ ):
+     # Fall back to the vocabulary shipped with the fine-tuned rxnfp BERT model.
+     if vocab_path is None:
+         vocab_path = pkg_resources.resource_filename(
+             "rxnfp",
+             "models/transformers/bert_ft/vocab.txt"
+         )
+     tokenizer = SmilesTokenizer(
+         vocab_path, max_len=max_len
+     )
+
+     feats = convert_reaction_to_valid_features_batch(
+         rxn_list,
+         tokenizer
+     )
+     X = [
+         torch.tensor(feats.input_ids.astype(np.int64)),
+         torch.tensor(feats.input_mask.astype(np.int64)),
+         torch.tensor(feats.segment_ids.astype(np.int64))
+     ]
+     y = torch.LongTensor(labels)
+     return X, y
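A hedged usage sketch; the reaction SMILES is illustrative, and rxnfp must be installed since `collate_rxn` defaults to the vocabulary bundled with its fine-tuned BERT model:

    from hdl.data.dataset.loaders.collate_funcs.rxn import collate_rxn

    rxns = ['CC(=O)O.OCC>>CC(=O)OCC']  # toy esterification, reactants>>product
    X, y = collate_rxn(rxns, labels=[1])
    input_ids, input_mask, segment_ids = X
    print(input_ids.shape, y)  # token ids of shape (1, seq_len), tensor([1])
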
hdl/data/dataset/loaders/general.py
@@ -0,0 +1,23 @@
+ import typing as t
+
+ import torch.utils.data as tud
+
+ from hdl.data.dataset.loaders.collate_funcs.fp import fp_collate
+
+
+ class Loader(tud.DataLoader):
+     def __init__(
+         self,
+         dataset,
+         batch_size: int = 128,
+         shuffle: bool = True,
+         num_workers: int = 12,
+         collate_fn: t.Callable = fp_collate
+     ):
+         super().__init__(
+             dataset,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             num_workers=num_workers,
+             collate_fn=collate_fn
+         )
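`Loader` is a thin `DataLoader` wrapper whose collate function defaults to `fp_collate`, so it suits the fingerprint datasets above. A minimal sketch, assuming `fp_dataset` yields `(fingerprints, labels)` tuples:

    from hdl.data.dataset.loaders.general import Loader

    loader = Loader(fp_dataset, batch_size=64, num_workers=4)
    fps, targets, targets_list = next(iter(loader))
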
hdl/data/dataset/loaders/spliter.py
@@ -0,0 +1,86 @@
+ from typing import DefaultDict, Tuple
+ from random import Random
+ from collections import defaultdict
+ from rdkit import Chem
+ from rdkit.Chem.Scaffolds import MurckoScaffold
+ from hdl.data.dataset.graph.chiral import MolDataset
+
+
+ def split_data(
+     smis: Tuple[str],
+     labels: Tuple,
+     split_type: str = "random",
+     sizes: Tuple[float, float, float] = (0.8, 0.2, 0.0),
+     seed: int = 999,
+     num_folds: int = 1,
+     balanced: bool = True,
+     args=None,
+ ) -> Tuple[Tuple[str], Tuple[str], Tuple[str]]:
+     random = Random(seed)
+
+     if split_type == "random":
+         indices = list(range(len(smis)))
+         random.shuffle(indices)
+
+         train_size = int(sizes[0] * len(smis))
+         train_val_size = int((sizes[0] + sizes[1]) * len(smis))
+         train = [
+             [smis[i] for i in indices[:train_size]],
+             [labels[i] for i in indices[:train_size]],
+         ]
+         val = [
+             [smis[i] for i in indices[train_size:train_val_size]],
+             [labels[i] for i in indices[train_size:train_val_size]],
+         ]
+         test = [
+             [smis[i] for i in indices[train_val_size:]],
+             [labels[i] for i in indices[train_val_size:]],
+         ]
+     elif split_type == "scaffold_balanced":
+         train_size, val_size, test_size = (
+             sizes[0] * len(smis),
+             sizes[1] * len(smis),
+             sizes[2] * len(smis),
+         )
+         train, val, test = [], [], []
+         train_scaffold_count, val_scaffold_count, test_scaffold_count = 0, 0, 0
+         # Group molecule indices by Bemis-Murcko scaffold.
+         scaffold_to_indices = defaultdict(set)
+         rdmols = [Chem.MolFromSmiles(s) for s in smis]
+         for i, rdmol in enumerate(rdmols):
+             scaffold = MurckoScaffold.MurckoScaffoldSmiles(
+                 mol=rdmol, includeChirality=False
+             )
+             scaffold_to_indices[scaffold].add(i)
+         if balanced:
+             # Keep scaffold groups too big for val/test at the front so they
+             # fall into train; shuffle within the two buckets.
+             index_sets = list(scaffold_to_indices.values())
+             big_index_sets = []
+             small_index_sets = []
+             for index_set in index_sets:
+                 if len(index_set) > val_size / 2 or len(index_set) > test_size / 2:
+                     big_index_sets.append(index_set)
+                 else:
+                     small_index_sets.append(index_set)
+             random.seed(seed)
+             random.shuffle(big_index_sets)
+             random.shuffle(small_index_sets)
+             index_sets = big_index_sets + small_index_sets
+         else:
+             index_sets = sorted(
+                 list(scaffold_to_indices.values()),
+                 key=lambda index_set: len(index_set),
+                 reverse=True,
+             )
+         for index_set in index_sets:
+             if len(train) + len(index_set) <= train_size:
+                 train += index_set
+                 train_scaffold_count += 1
+             elif len(val) + len(index_set) <= val_size:
+                 val += index_set
+                 val_scaffold_count += 1
+             else:
+                 test += index_set
+                 test_scaffold_count += 1
+         # Return (smiles, labels) pairs, matching the random branch.
+         train = [[smis[i] for i in train], [labels[i] for i in train]]
+         val = [[smis[i] for i in val], [labels[i] for i in val]]
+         test = [[smis[i] for i in test], [labels[i] for i in test]]
+     return train, val, test
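A quick random-split check with toy SMILES (each returned split is a `[smiles, labels]` pair):

    from hdl.data.dataset.loaders.spliter import split_data

    smis = ['CCO', 'CCN', 'CCC', 'CCCl', 'CCBr']
    labels = [0, 1, 0, 1, 0]
    train, val, test = split_data(smis, labels, split_type='random',
                                  sizes=(0.6, 0.2, 0.2))
    train_smis, train_labels = train
    print(len(train_smis), len(val[0]), len(test[0]))  # 3 1 1
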
hdl/data/dataset/samplers/chiral.py
@@ -0,0 +1,19 @@
+ from itertools import chain
+
+ import numpy as np
+ from torch.utils.data.sampler import Sampler
+
+
+ class StereoSampler(Sampler):
+
+     def __init__(self, data_source):
+         self.data_source = data_source
+
+     def __iter__(self):
+         # Shuffle (2i, 2i + 1) index pairs as units so stereoisomer pairs
+         # stay adjacent; assumes the dataset has an even number of samples.
+         groups = [[i, i + 1] for i in range(0, len(self.data_source), 2)]
+         np.random.shuffle(groups)
+         indices = list(chain(*groups))
+         return iter(indices)
+
+     def __len__(self):
+         return len(self.data_source)
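The pair-preserving shuffle can be checked directly with a plain list as the data source:

    from hdl.data.dataset.samplers.chiral import StereoSampler

    sampler = StereoSampler(list(range(6)))
    print(list(sampler))  # e.g. [4, 5, 0, 1, 2, 3] -- (2i, 2i+1) pairs stay adjacent
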