hjxdl 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. hdl/__init__.py +0 -0
  2. hdl/_version.py +16 -0
  3. hdl/args/__init__.py +0 -0
  4. hdl/args/loss_args.py +5 -0
  5. hdl/controllers/__init__.py +0 -0
  6. hdl/controllers/al/__init__.py +0 -0
  7. hdl/controllers/al/al.py +0 -0
  8. hdl/controllers/al/dispatcher.py +0 -0
  9. hdl/controllers/al/feedback.py +0 -0
  10. hdl/controllers/explain/__init__.py +0 -0
  11. hdl/controllers/explain/shapley.py +293 -0
  12. hdl/controllers/explain/subgraphx.py +865 -0
  13. hdl/controllers/train/__init__.py +0 -0
  14. hdl/controllers/train/rxn_train.py +219 -0
  15. hdl/controllers/train/train.py +50 -0
  16. hdl/controllers/train/train_ginet.py +316 -0
  17. hdl/controllers/train/trainer_base.py +155 -0
  18. hdl/controllers/train/trainer_iterative.py +389 -0
  19. hdl/data/__init__.py +0 -0
  20. hdl/data/dataset/__init__.py +0 -0
  21. hdl/data/dataset/base_dataset.py +98 -0
  22. hdl/data/dataset/fp/__init__.py +0 -0
  23. hdl/data/dataset/fp/fp_dataset.py +122 -0
  24. hdl/data/dataset/graph/__init__.py +0 -0
  25. hdl/data/dataset/graph/chiral.py +62 -0
  26. hdl/data/dataset/graph/gin.py +255 -0
  27. hdl/data/dataset/graph/molnet.py +362 -0
  28. hdl/data/dataset/loaders/__init__.py +0 -0
  29. hdl/data/dataset/loaders/chiral_graph.py +71 -0
  30. hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
  31. hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
  32. hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
  33. hdl/data/dataset/loaders/general.py +23 -0
  34. hdl/data/dataset/loaders/spliter.py +86 -0
  35. hdl/data/dataset/samplers/__init__.py +0 -0
  36. hdl/data/dataset/samplers/chiral.py +19 -0
  37. hdl/data/dataset/seq/__init__.py +0 -0
  38. hdl/data/dataset/seq/rxn_dataset.py +61 -0
  39. hdl/data/dataset/utils.py +31 -0
  40. hdl/data/to_mols.py +0 -0
  41. hdl/features/__init__.py +0 -0
  42. hdl/features/fp/__init__.py +0 -0
  43. hdl/features/fp/features_generators.py +235 -0
  44. hdl/features/graph/__init__.py +0 -0
  45. hdl/features/graph/featurization.py +297 -0
  46. hdl/features/utils/__init__.py +0 -0
  47. hdl/features/utils/utils.py +111 -0
  48. hdl/layers/__init__.py +0 -0
  49. hdl/layers/general/__init__.py +0 -0
  50. hdl/layers/general/gp.py +14 -0
  51. hdl/layers/general/linear.py +641 -0
  52. hdl/layers/graph/__init__.py +0 -0
  53. hdl/layers/graph/chiral_graph.py +230 -0
  54. hdl/layers/graph/gcn.py +16 -0
  55. hdl/layers/graph/gin.py +45 -0
  56. hdl/layers/graph/tetra.py +158 -0
  57. hdl/layers/graph/transformer.py +188 -0
  58. hdl/layers/sequential/__init__.py +0 -0
  59. hdl/metric_loss/__init__.py +0 -0
  60. hdl/metric_loss/loss.py +79 -0
  61. hdl/metric_loss/metric.py +178 -0
  62. hdl/metric_loss/multi_label.py +42 -0
  63. hdl/metric_loss/nt_xent.py +65 -0
  64. hdl/models/__init__.py +0 -0
  65. hdl/models/chiral_gnn.py +176 -0
  66. hdl/models/fast_transformer.py +234 -0
  67. hdl/models/ginet.py +189 -0
  68. hdl/models/linear.py +137 -0
  69. hdl/models/model_dict.py +18 -0
  70. hdl/models/norm_flows.py +33 -0
  71. hdl/models/optim_dict.py +16 -0
  72. hdl/models/rxn.py +63 -0
  73. hdl/models/utils.py +83 -0
  74. hdl/ops/__init__.py +0 -0
  75. hdl/ops/utils.py +42 -0
  76. hdl/optims/__init__.py +0 -0
  77. hdl/optims/nadam.py +86 -0
  78. hdl/utils/__init__.py +0 -0
  79. hdl/utils/chemical_tools/__init__.py +2 -0
  80. hdl/utils/chemical_tools/query_info.py +149 -0
  81. hdl/utils/chemical_tools/sdf.py +20 -0
  82. hdl/utils/database_tools/__init__.py +0 -0
  83. hdl/utils/database_tools/connect.py +28 -0
  84. hdl/utils/general/__init__.py +0 -0
  85. hdl/utils/general/glob.py +21 -0
  86. hdl/utils/schedulers/__init__.py +0 -0
  87. hdl/utils/schedulers/norm_lr.py +108 -0
  88. hjxdl-0.0.1.dist-info/METADATA +19 -0
  89. hjxdl-0.0.1.dist-info/RECORD +91 -0
  90. hjxdl-0.0.1.dist-info/WHEEL +5 -0
  91. hjxdl-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,122 @@
1
+ import typing as t
2
+
3
+ import torch
4
+ from rdkit import Chem
5
+ # import torch.utils.data as tud
6
+
7
+ from hdl.data.dataset.base_dataset import CSVDataset, CSVRDataset
8
+ from hdl.features.fp.features_generators import (
9
+ get_features_generator,
10
+ get_available_features_generators,
11
+ FP_BITS_DICT
12
+ )
13
+ # from hdl.features.fp.rxn import get_rxnrep_fingerprint
14
+
15
+
16
class FPDataset(CSVDataset):
    """CSV-backed dataset that yields one fingerprint tensor per SMILES column.

    The row storage (``self.df``) and the ``target_cols`` / ``num_classes`` /
    ``target_transform`` attributes read in ``__getitem__`` are not assigned
    in this class; presumably ``CSVDataset.__init__`` sets them from the
    arguments forwarded to it -- confirm against the base class.
    """

    def __init__(
        self,
        csv_file: str,
        splitter: str,
        smiles_cols: t.List,
        target_cols: t.Optional[t.List] = None,
        missing_labels: t.Optional[t.List] = None,
        num_classes: t.Optional[t.List] = None,
        target_transform: t.Union[str, t.List[str]] = None,
        fp_type: str = 'morgan_count',
        **kwargs
    ) -> None:
        """Configure the fingerprint generator and forward the rest to the base.

        Args:
            csv_file: Forwarded to ``CSVDataset``.
            splitter: Forwarded to ``CSVDataset``.
            smiles_cols: Columns holding SMILES strings; one fingerprint is
                produced per column.
            target_cols: Columns holding targets. ``None`` means no targets.
            missing_labels: Per-target sentinel values that mark a missing
                label, positionally aligned with ``target_cols``. ``None``
                entries (or a list shorter than ``target_cols``) mean "no
                sentinel" for that target.
            num_classes: Per-target class counts handed to the transforms.
            target_transform: Forwarded to ``CSVDataset``.
            fp_type: Name of a registered features generator.
            **kwargs: Extra keyword arguments forwarded to ``CSVDataset``.
        """
        # Fresh lists per instance: a literal ``[]`` default would be shared
        # between every dataset constructed without these arguments.
        target_cols = [] if target_cols is None else target_cols
        missing_labels = [] if missing_labels is None else missing_labels
        num_classes = [] if num_classes is None else num_classes

        super().__init__(
            csv_file,
            splitter=splitter,
            smiles_col=smiles_cols,
            target_cols=target_cols,
            num_classes=num_classes,
            target_transform=target_transform,
            **kwargs
        )
        self.smiles_cols = smiles_cols
        assert fp_type in get_available_features_generators(), (
            f'unknown fp_type: {fp_type!r}'
        )
        self.fp_type = fp_type
        self.fp_generator = get_features_generator(self.fp_type)
        self.fp_numbits = FP_BITS_DICT[self.fp_type]
        self.missing_labels = missing_labels

    def __getitem__(self, index):
        """Return the fingerprints (and targets, if configured) for one row.

        Returns:
            * ``fingerprint_list`` when no target columns are configured;
            * ``(fingerprint_list, final_targets)`` when targets exist but
              ``self.target_transform`` is ``None``;
            * ``(fingerprint_list, target_tensors, final_targets)`` otherwise.

            ``final_targets`` holds the raw targets with every value equal to
            its missing-label sentinel replaced by ``nan``.
        """
        record = self.df.loc[index]
        smiles_list = record[self.smiles_cols].tolist()
        fingerprint_list = [
            torch.LongTensor(self.fp_generator(Chem.MolFromSmiles(smi)))
            for smi in smiles_list
        ]

        if not any(self.target_cols):
            return fingerprint_list

        target_list = record[self.target_cols].tolist()

        # Pad the sentinels with None so that zip() cannot truncate the
        # targets: with the default (empty) missing_labels the original zip
        # produced an empty target list and every label was silently lost.
        sentinels = list(self.missing_labels or [])
        sentinels += [None] * (len(target_list) - len(sentinels))

        final_targets = [
            float('nan')
            if sentinel is not None and target == sentinel
            else target
            for target, sentinel in zip(target_list, sentinels)
        ]

        if self.target_transform is None:
            return fingerprint_list, final_targets

        # Each transform receives its own target and class count; missing
        # targets have already been normalised to nan above.
        target_tensors = [
            trans(target, num_class, missing_label=float('nan'))
            for trans, target, num_class in zip(
                self.target_transform,
                final_targets,
                self.num_classes
            )
        ]
        return fingerprint_list, target_tensors, final_targets
81
+
82
+
83
class FPRDataset(CSVRDataset):
    """Single-SMILES-column dataset that yields one fingerprint per row.

    ``self.df``, ``self.smiles_col`` and ``self.target_col`` are read in
    ``__getitem__`` but not assigned here; presumably ``CSVRDataset.__init__``
    provides them -- confirm against the base class.
    """

    def __init__(
        self,
        csv_file: str,
        splitter: str,
        smiles_col: str,
        target_col: str = None,
        missing_label: str = None,
        target_transform: t.Union[str, t.List[str]] = None,
        fp_type: str = 'morgan_count',
        **kwargs
    ) -> None:
        """Forward the CSV options to the base and set up the fingerprinter."""
        base_options = dict(
            splitter=splitter,
            smiles_col=smiles_col,
            target_col=target_col,
            target_transform=target_transform,
            missing_label=missing_label,
        )
        super().__init__(csv_file, **base_options, **kwargs)

        assert fp_type in get_available_features_generators()
        self.fp_type = fp_type
        self.fp_generator = get_features_generator(fp_type)
        self.fp_numbits = FP_BITS_DICT[fp_type]
        self.missing_label = missing_label

    def __getitem__(self, index):
        """Return ``fp`` or ``(fp, (target,))`` for the row at ``index``.

        Any exception raised while parsing the SMILES or computing the
        fingerprint is swallowed and an all-zero long vector of length
        ``self.fp_numbits`` is returned in its place.
        """
        smiles = self.df.loc[index][self.smiles_col]
        try:
            mol = Chem.MolFromSmiles(smiles)
            fp = torch.LongTensor(self.fp_generator(mol))
        except Exception:
            # Best-effort fallback: keep the sample, with an empty fingerprint.
            fp = torch.zeros(self.fp_numbits).long()

        if self.target_col is None:
            return fp

        # The target is wrapped in a 1-tuple, as in the original layout.
        target = (self.df.loc[index][self.target_col], )
        return fp, target
File without changes
@@ -0,0 +1,62 @@
1
+ import typing as t
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch._C import dtype
6
+ import torch_geometric as tg
7
+ from torch_geometric.data import Dataset
8
+
9
+ from hdl.features.graph.featurization import MolGraph
10
+
11
+
12
class MolDataset(Dataset):
    """Graph dataset that featurises SMILES into ``Data`` objects on access."""

    def __init__(
        self,
        smiles: t.List,
        labels: t.List,
        chiral_features: bool = False,
        global_chiral_features: bool = False,
    ):
        """Store the inputs and cache the mean / std of ``labels``.

        Args:
            smiles: SMILES strings, indexed by sample key.
            labels: One label per SMILES, indexed the same way.
            chiral_features: Passed through to ``MolGraph``.
            global_chiral_features: Passed through to ``MolGraph``.
        """
        super().__init__()

        self.smiles = smiles
        self.labels = labels
        self.chiral_features = chiral_features
        self.global_chiral_features = global_chiral_features

        # Label statistics computed once over the whole label list.
        self.mean = np.mean(self.labels)
        self.std = np.std(self.labels)

    def process_key(self, key):
        """Featurise the molecule stored at ``key`` and return its ``Data``."""
        molgraph = MolGraph(
            self.smiles[key],
            self.chiral_features,
            self.global_chiral_features
        )
        return self.molgraph2data(molgraph, key)

    def molgraph2data(self, molgraph, key):
        """Copy the arrays of ``molgraph`` into a ``torch_geometric`` ``Data``.

        The label and SMILES attached to the result are those stored at
        ``key``.
        """
        as_float = torch.float
        as_long = torch.long

        data = tg.data.Data()
        data.x = torch.tensor(molgraph.f_atoms, dtype=as_float)
        # The stored edge index is transposed before being attached.
        pairs = torch.tensor(molgraph.edge_index, dtype=as_long)
        data.edge_index = pairs.t().contiguous()
        data.edge_attr = torch.tensor(molgraph.f_bonds, dtype=as_float)
        data.y = torch.tensor([self.labels[key]], dtype=as_float)
        data.parity_atoms = torch.tensor(molgraph.parity_atoms, dtype=as_long)
        data.parity_bond_index = torch.tensor(
            molgraph.parity_bond_index, dtype=as_long
        )
        data.smiles = self.smiles[key]

        return data

    def __len__(self):
        """Return the number of SMILES strings held."""
        return len(self.smiles)

    def __getitem__(self, key):
        """Return the featurised ``Data`` object for ``key``."""
        return self.process_key(key)
@@ -0,0 +1,255 @@
1
+ # import os
2
+ import csv
3
+ import math
4
+ # import time
5
+ import random
6
+ # import networkx as nx
7
+ import numpy as np
8
+ from copy import deepcopy
9
+ import typing as t
10
+
11
+ import torch
12
+ # import torch.nn.functional as F
13
+ # from torch.utils.data import Dataset, DataLoader
14
+ from torch.utils.data.sampler import SubsetRandomSampler
15
+ # import torchvision.transforms as transforms
16
+
17
+ # from torch_scatter import scatter
18
+ from torch_geometric.data import Data, Dataset
19
+ from torch_geometric.loader import DataLoader
20
+
21
+ # import rdkit
22
+ from rdkit import Chem
23
+ # from rdkit.Chem.rdchem import HybridizationType
24
+ from rdkit.Chem.rdchem import BondType as BT
25
+ # from rdkit.Chem import AllChem
26
+
27
+ from hdl.data.dataset.utils import read_smiles
28
+
29
+
30
__all__ = ["MoleculeDataset", "MoleculeDatasetWrapper"]


# Lookup vocabularies: a feature value is encoded as its position in the
# corresponding list (via ``list.index``).

# Atomic numbers 1..118.
ATOM_LIST = list(range(1, 119))

# RDKit chiral tags.
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER,
]

# RDKit bond types.
BOND_LIST = [BT.SINGLE, BT.DOUBLE, BT.TRIPLE, BT.AROMATIC]

# RDKit bond directions.
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT,
]
54
+
55
+
56
class MoleculeDataset(Dataset):
    """Dataset producing two randomly masked graph views per SMILES string.

    Each molecule is encoded with integer atom features (element index,
    chirality index) and integer bond features (bond-type index,
    bond-direction index). Two views are then derived from it by masking a
    random quarter of the atoms and removing a random quarter of the bonds.
    """

    def __init__(
        self,
        data_path,
        file_type: str = 'smi',
        smi_col_names: t.Optional[t.List] = None,
        y_col_name: str = None,
    ):
        """Load the SMILES records through ``read_smiles``.

        Args:
            data_path: Forwarded to ``read_smiles``.
            file_type: Forwarded to ``read_smiles``.
            smi_col_names: Names of the SMILES columns. ``None`` (the
                default) is treated as an empty list, in which case each
                record is taken to be a single SMILES string.
            y_col_name: Name of the label column, if any.
        """
        # Initialise the real parent class. ``super(Dataset, self)`` started
        # the lookup *after* ``Dataset`` and therefore skipped its __init__.
        super().__init__()
        # Fresh list per instance instead of a shared mutable default.
        smi_col_names = [] if smi_col_names is None else smi_col_names
        self.smiles_data = read_smiles(
            data_path=data_path,
            file_type=file_type,
            smi_col_names=smi_col_names,
            y_col_name=y_col_name
        )
        self.smi_col_names = smi_col_names
        self.y_col_name = y_col_name

    def __getitem__(
        self,
        idx: int
    ):
        """Return the augmented views for record ``idx``.

        Without SMILES column names the result is the ``(data_i, data_j)``
        pair for the record. With column names it is a list holding one such
        pair per SMILES column, followed by the record's last field as a
        ``float`` when ``y_col_name`` is set.
        """
        if not any(self.smi_col_names):
            return self.getitem(self.smiles_data[idx])

        record = self.smiles_data[idx]
        item = [
            self.getitem(smiles)
            for smiles in record[: len(self.smi_col_names)]
        ]
        if self.y_col_name is not None:
            item.append(float(record[-1]))
        return item

    def getitem(self, smiles):
        """Build two independently masked ``Data`` views of one molecule.

        Args:
            smiles: SMILES string of the molecule.

        Returns:
            Tuple ``(data_i, data_j)`` of ``Data`` objects.
        """
        mol = Chem.MolFromSmiles(smiles)

        N = mol.GetNumAtoms()
        M = mol.GetNumBonds()

        type_idx = []
        chirality_idx = []
        for atom in mol.GetAtoms():
            type_idx.append(ATOM_LIST.index(atom.GetAtomicNum()))
            chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))

        x1 = torch.tensor(type_idx, dtype=torch.long).view(-1, 1)
        x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1, 1)
        x = torch.cat([x1, x2], dim=-1)

        # Every bond is stored as two directed edges (columns 2k and 2k+1)
        # carrying the same features.
        row, col, edge_feat = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row += [start, end]
            col += [end, start]
            feat = [
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ]
            edge_feat.append(feat)
            edge_feat.append(list(feat))

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(np.array(edge_feat), dtype=torch.long)

        # Randomly mask a subgraph of the molecule. All four samples are
        # drawn up front, in this order, so that results under a fixed
        # ``random`` seed do not depend on how the views are assembled.
        num_mask_nodes = max([1, math.floor(0.25*N)])
        num_mask_edges = max([0, math.floor(0.25*M)])
        mask_nodes_i = random.sample(list(range(N)), num_mask_nodes)
        mask_nodes_j = random.sample(list(range(N)), num_mask_nodes)
        mask_bonds_i = random.sample(list(range(M)), num_mask_edges)
        mask_bonds_j = random.sample(list(range(M)), num_mask_edges)

        data_i = self._masked_view(
            x, edge_index, edge_attr, 2*M, mask_nodes_i, mask_bonds_i
        )
        data_j = self._masked_view(
            x, edge_index, edge_attr, 2*M, mask_nodes_j, mask_bonds_j
        )
        return data_i, data_j

    @staticmethod
    def _masked_view(x, edge_index, edge_attr, num_edges, mask_nodes, mask_bonds):
        """Return a ``Data`` view with the given atoms masked and bonds removed."""
        x_view = x.clone()
        # Masked atoms get the out-of-vocabulary element index and chirality 0.
        mask_token = torch.tensor([len(ATOM_LIST), 0])
        for atom_idx in mask_nodes:
            x_view[atom_idx, :] = mask_token

        # Removing bond k removes both of its directed edges. A set keeps the
        # membership test O(1) inside the comprehension.
        dropped = {2*k for k in mask_bonds} | {2*k + 1 for k in mask_bonds}
        kept = [e for e in range(num_edges) if e not in dropped]

        # Pre-shaped outputs keep the (2, E) / (E, 2) layout even when E == 0.
        edge_index_view = torch.zeros((2, len(kept)), dtype=torch.long)
        edge_attr_view = torch.zeros((len(kept), 2), dtype=torch.long)
        for slot, edge in enumerate(kept):
            edge_index_view[:, slot] = edge_index[:, edge]
            edge_attr_view[slot, :] = edge_attr[edge, :]

        return Data(
            x=x_view, edge_index=edge_index_view, edge_attr=edge_attr_view
        )

    def __len__(self):
        """Return the number of loaded SMILES records."""
        return len(self.smiles_data)
180
+
181
+
182
class MoleculeDatasetWrapper(object):
    """Build train / validation / test ``DataLoader``s over ``MoleculeDataset``."""

    def __init__(
        self,
        batch_size,
        num_workers,
        valid_size,
        data_path,
        file_type: str = 'smi',
        smi_col_names: t.Optional[t.List] = None,
        y_col_name: str = None,
    ):
        """Store the loader and dataset configuration.

        Args:
            batch_size: Batch size used by every loader.
            num_workers: Worker count used by every loader.
            valid_size: Fraction of the samples used for validation.
            data_path: Path handed to ``MoleculeDataset``.
            file_type: File type handed to ``MoleculeDataset``.
            smi_col_names: SMILES column names handed to ``MoleculeDataset``;
                ``None`` (the default) becomes a fresh empty list.
            y_col_name: Label column name handed to ``MoleculeDataset``.
        """
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.valid_size = valid_size
        self.file_type = file_type
        # Fresh list per instance instead of a shared mutable default.
        self.smi_col_names = [] if smi_col_names is None else smi_col_names
        self.y_col_name = y_col_name

    def _make_dataset(self):
        """Create a ``MoleculeDataset`` from the stored configuration."""
        return MoleculeDataset(
            data_path=self.data_path,
            file_type=self.file_type,
            smi_col_names=self.smi_col_names,
            y_col_name=self.y_col_name
        )

    def get_data_loaders(self):
        """Return ``(train_loader, valid_loader)`` over a random split.

        The dataset is built with the configured file type and column names,
        the same way ``get_test_loader`` builds it; previously these settings
        were ignored here and the defaults were always used.
        """
        train_dataset = self._make_dataset()
        return self.get_train_validation_data_loaders(train_dataset)

    def get_test_loader(self, shuffle=False):
        """Return a single loader over the whole dataset."""
        return self.get_test_data_loader(
            self._make_dataset(),
            shuffle=shuffle
        )

    def get_train_validation_data_loaders(self, train_dataset):
        """Randomly split ``train_dataset`` and wrap both parts in loaders.

        The first ``floor(valid_size * len(train_dataset))`` shuffled indices
        form the validation subset and the rest the training subset. Both
        loaders use ``drop_last=True``.
        """
        num_train = len(train_dataset)
        indices = list(range(num_train))
        np.random.shuffle(indices)

        split = int(np.floor(self.valid_size * num_train))
        train_idx, valid_idx = indices[split:], indices[:split]

        # Both loaders share the dataset and differ only in their sampler.
        loaders = [
            DataLoader(
                train_dataset,
                batch_size=self.batch_size,
                sampler=SubsetRandomSampler(subset),
                num_workers=self.num_workers,
                drop_last=True
            )
            for subset in (train_idx, valid_idx)
        ]
        train_loader, valid_loader = loaders
        return train_loader, valid_loader

    def get_test_data_loader(
        self,
        test_dataset,
        shuffle=False
    ):
        """Wrap ``test_dataset`` in a loader that keeps the last partial batch."""
        return DataLoader(
            test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            drop_last=False,
            shuffle=shuffle
        )