hjxdl 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. hdl/__init__.py +0 -0
  2. hdl/_version.py +16 -0
  3. hdl/args/__init__.py +0 -0
  4. hdl/args/loss_args.py +5 -0
  5. hdl/controllers/__init__.py +0 -0
  6. hdl/controllers/al/__init__.py +0 -0
  7. hdl/controllers/al/al.py +0 -0
  8. hdl/controllers/al/dispatcher.py +0 -0
  9. hdl/controllers/al/feedback.py +0 -0
  10. hdl/controllers/explain/__init__.py +0 -0
  11. hdl/controllers/explain/shapley.py +293 -0
  12. hdl/controllers/explain/subgraphx.py +865 -0
  13. hdl/controllers/train/__init__.py +0 -0
  14. hdl/controllers/train/rxn_train.py +219 -0
  15. hdl/controllers/train/train.py +50 -0
  16. hdl/controllers/train/train_ginet.py +316 -0
  17. hdl/controllers/train/trainer_base.py +155 -0
  18. hdl/controllers/train/trainer_iterative.py +389 -0
  19. hdl/data/__init__.py +0 -0
  20. hdl/data/dataset/__init__.py +0 -0
  21. hdl/data/dataset/base_dataset.py +98 -0
  22. hdl/data/dataset/fp/__init__.py +0 -0
  23. hdl/data/dataset/fp/fp_dataset.py +122 -0
  24. hdl/data/dataset/graph/__init__.py +0 -0
  25. hdl/data/dataset/graph/chiral.py +62 -0
  26. hdl/data/dataset/graph/gin.py +255 -0
  27. hdl/data/dataset/graph/molnet.py +362 -0
  28. hdl/data/dataset/loaders/__init__.py +0 -0
  29. hdl/data/dataset/loaders/chiral_graph.py +71 -0
  30. hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
  31. hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
  32. hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
  33. hdl/data/dataset/loaders/general.py +23 -0
  34. hdl/data/dataset/loaders/spliter.py +86 -0
  35. hdl/data/dataset/samplers/__init__.py +0 -0
  36. hdl/data/dataset/samplers/chiral.py +19 -0
  37. hdl/data/dataset/seq/__init__.py +0 -0
  38. hdl/data/dataset/seq/rxn_dataset.py +61 -0
  39. hdl/data/dataset/utils.py +31 -0
  40. hdl/data/to_mols.py +0 -0
  41. hdl/features/__init__.py +0 -0
  42. hdl/features/fp/__init__.py +0 -0
  43. hdl/features/fp/features_generators.py +235 -0
  44. hdl/features/graph/__init__.py +0 -0
  45. hdl/features/graph/featurization.py +297 -0
  46. hdl/features/utils/__init__.py +0 -0
  47. hdl/features/utils/utils.py +111 -0
  48. hdl/layers/__init__.py +0 -0
  49. hdl/layers/general/__init__.py +0 -0
  50. hdl/layers/general/gp.py +14 -0
  51. hdl/layers/general/linear.py +641 -0
  52. hdl/layers/graph/__init__.py +0 -0
  53. hdl/layers/graph/chiral_graph.py +230 -0
  54. hdl/layers/graph/gcn.py +16 -0
  55. hdl/layers/graph/gin.py +45 -0
  56. hdl/layers/graph/tetra.py +158 -0
  57. hdl/layers/graph/transformer.py +188 -0
  58. hdl/layers/sequential/__init__.py +0 -0
  59. hdl/metric_loss/__init__.py +0 -0
  60. hdl/metric_loss/loss.py +79 -0
  61. hdl/metric_loss/metric.py +178 -0
  62. hdl/metric_loss/multi_label.py +42 -0
  63. hdl/metric_loss/nt_xent.py +65 -0
  64. hdl/models/__init__.py +0 -0
  65. hdl/models/chiral_gnn.py +176 -0
  66. hdl/models/fast_transformer.py +234 -0
  67. hdl/models/ginet.py +189 -0
  68. hdl/models/linear.py +137 -0
  69. hdl/models/model_dict.py +18 -0
  70. hdl/models/norm_flows.py +33 -0
  71. hdl/models/optim_dict.py +16 -0
  72. hdl/models/rxn.py +63 -0
  73. hdl/models/utils.py +83 -0
  74. hdl/ops/__init__.py +0 -0
  75. hdl/ops/utils.py +42 -0
  76. hdl/optims/__init__.py +0 -0
  77. hdl/optims/nadam.py +86 -0
  78. hdl/utils/__init__.py +0 -0
  79. hdl/utils/chemical_tools/__init__.py +2 -0
  80. hdl/utils/chemical_tools/query_info.py +149 -0
  81. hdl/utils/chemical_tools/sdf.py +20 -0
  82. hdl/utils/database_tools/__init__.py +0 -0
  83. hdl/utils/database_tools/connect.py +28 -0
  84. hdl/utils/general/__init__.py +0 -0
  85. hdl/utils/general/glob.py +21 -0
  86. hdl/utils/schedulers/__init__.py +0 -0
  87. hdl/utils/schedulers/norm_lr.py +108 -0
  88. hjxdl-0.0.1.dist-info/METADATA +19 -0
  89. hjxdl-0.0.1.dist-info/RECORD +91 -0
  90. hjxdl-0.0.1.dist-info/WHEEL +5 -0
  91. hjxdl-0.0.1.dist-info/top_level.txt +1 -0
hdl/__init__.py ADDED
File without changes
hdl/_version.py ADDED
@@ -0,0 +1,16 @@
1
# file generated by setuptools_scm
# don't change, don't track in version control

TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple, Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

__version__ = version = '0.0.1'
__version_tuple__ = version_tuple = (0, 0, 1)
hdl/args/__init__.py ADDED
File without changes
hdl/args/loss_args.py ADDED
@@ -0,0 +1,5 @@
1
from tap import Tap


# Typed command-line arguments controlling loss computation.
# NOTE: Tap turns inline comments and docstrings into --help text, so the
# description is kept on a detached line to avoid changing CLI output.
class LossArgs(Tap):
    reduction: str = 'mean'
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
@@ -0,0 +1,293 @@
1
+ import copy
2
+ import torch
3
+ import numpy as np
4
+ from scipy.special import comb
5
+ from itertools import combinations
6
+ import torch.nn.functional as F
7
+ from torch_geometric.utils import to_networkx
8
+ from torch_geometric.data import Data, Batch, Dataset, DataLoader
9
+
10
+
11
def GnnNetsGC2valueFunc(gnnNets, target_class):
    """Wrap a graph-classification model into a coalition value function.

    The returned callable runs ``gnnNets`` on a batch (no gradients),
    applies a softmax over the last dimension and returns the probability
    of ``target_class`` for every graph in the batch.
    """
    def value_func(batch):
        with torch.no_grad():
            class_probs = F.softmax(gnnNets(data=batch), dim=-1)
            return class_probs[:, target_class]
    return value_func
19
+
20
+
21
def GnnNetsNC2valueFunc(gnnNets_NC, node_idx, target_class):
    """Wrap a node-classification model into a coalition value function.

    The returned callable evaluates ``gnnNets_NC`` (no gradients) on a batch
    of sampled graphs and returns, for each graph, the probability of
    ``target_class`` at node ``node_idx``.
    """
    def value_func(data):
        with torch.no_grad():
            probs = F.softmax(gnnNets_NC(data=data), dim=-1)
            # One score per sampled graph: group the flat node probabilities
            # back into (num_graphs, nodes_per_graph, num_classes).
            num_graphs = data.batch.max() + 1
            per_graph = probs.reshape(num_graphs, -1, probs.shape[-1])
            return per_graph[:, node_idx, target_class]
    return value_func
32
+
33
+
34
def get_graph_build_func(build_method):
    """Resolve a subgraph-construction strategy by (case-insensitive) name.

    'zero_filling' masks unselected node features to zero; 'split' drops
    edges touching unselected nodes. Any other name raises
    NotImplementedError.
    """
    method = build_method.lower()
    if method == 'zero_filling':
        return graph_build_zero_filling
    if method == 'split':
        return graph_build_split
    raise NotImplementedError
41
+
42
+
43
class MarginalSubgraphDataset(Dataset):
    """Dataset of (excluded, included) subgraph pairs for marginal scoring.

    Each index corresponds to one row of the exclude/include node masks;
    ``subgraph_build_func`` turns a mask into a concrete subgraph of the
    original graph.
    """

    def __init__(self, data, exclude_mask, include_mask, subgraph_build_func):
        # NOTE(review): the base Dataset __init__ is deliberately not
        # invoked, matching the surrounding code's convention.
        self.num_nodes = data.num_nodes
        self.X = data.x
        self.edge_index = data.edge_index
        self.device = self.X.device

        self.label = data.y
        # Masks are kept on the same device as the node features.
        self.exclude_mask = torch.tensor(exclude_mask).type(torch.float32).to(self.device)
        self.include_mask = torch.tensor(include_mask).type(torch.float32).to(self.device)
        self.subgraph_build_func = subgraph_build_func

    def __len__(self):
        # One sample per mask row.
        return self.exclude_mask.shape[0]

    def __getitem__(self, idx):
        excl_x, excl_edges = self.subgraph_build_func(self.X, self.edge_index, self.exclude_mask[idx])
        incl_x, incl_edges = self.subgraph_build_func(self.X, self.edge_index, self.include_mask[idx])
        exclude_data = Data(x=excl_x, edge_index=excl_edges)
        include_data = Data(x=incl_x, edge_index=incl_edges)
        return exclude_data, include_data
64
+
65
+
66
def marginal_contribution(data: "Data", exclude_mask: np.ndarray, include_mask: np.ndarray,
                          value_func, subgraph_build_func):
    """Calculate the marginal value for each exclude/include mask pair.

    ``exclude_mask`` and ``include_mask`` are (num_pairs, num_nodes) node
    masks; for every pair the difference value_func(include) -
    value_func(exclude) is computed in mini-batches of 256.

    Returns a 1-D tensor of marginal contributions, one per mask row.
    """
    # Fixed annotations: np.array is a factory function, not a type.
    marginal_subgraph_dataset = MarginalSubgraphDataset(data, exclude_mask, include_mask, subgraph_build_func)
    dataloader = DataLoader(marginal_subgraph_dataset, batch_size=256, shuffle=False, num_workers=0)

    marginal_contribution_list = []

    for exclude_data, include_data in dataloader:
        exclude_values = value_func(exclude_data)
        include_values = value_func(include_data)
        margin_values = include_values - exclude_values
        marginal_contribution_list.append(margin_values)

    marginal_contributions = torch.cat(marginal_contribution_list, dim=0)
    return marginal_contributions
82
+
83
+
84
def graph_build_zero_filling(X, edge_index, node_mask: torch.Tensor):
    """Build a subgraph by zeroing the features of unselected nodes.

    The edge structure is untouched; only rows of ``X`` whose mask entry is
    zero are blanked. ``node_mask`` is a float tensor (the original
    ``np.array`` annotation was wrong — callers pass torch tensors and
    ``.unsqueeze`` is a tensor method).
    """
    ret_X = X * node_mask.unsqueeze(1)
    return ret_X, edge_index
88
+
89
+
90
def graph_build_split(X, edge_index, node_mask: torch.Tensor):
    """Build a subgraph by splitting the selected nodes from the graph.

    Node features are kept as-is; only edges whose both endpoints are
    selected (mask == 1) survive. ``node_mask`` is a float tensor (the
    original ``np.array`` annotation was wrong — tensor indexing is used).
    """
    ret_X = X
    row, col = edge_index
    # Keep an edge only when both of its endpoints are selected.
    edge_mask = (node_mask[row] == 1) & (node_mask[col] == 1)
    ret_edge_index = edge_index[:, edge_mask]
    return ret_X, ret_edge_index
97
+
98
+
99
def l_shapley(coalition: list, data: "Data", local_radius: int,
              value_func, subgraph_building_method='zero_filling'):
    """Exact Shapley value where players are local neighbor nodes.

    Args:
        coalition: node indices whose joint contribution is scored.
        data: the full input graph.
        local_radius: number of hops used to grow the local region.
        value_func: callable mapping a batched graph to per-graph scores
            (the original ``str`` annotation was incorrect).
        subgraph_building_method: 'zero_filling' or 'split'.

    Returns:
        float: the local Shapley value of the coalition.
    """
    graph = to_networkx(data)
    num_nodes = graph.number_of_nodes()
    subgraph_build_func = get_graph_build_func(subgraph_building_method)

    # Grow the (local_radius - 1)-hop neighborhood around the coalition.
    local_region = copy.copy(coalition)
    for _ in range(local_radius - 1):
        k_neighborhood = []
        for node in local_region:
            k_neighborhood += list(graph.neighbors(node))
        local_region += k_neighborhood
        local_region = list(set(local_region))

    set_exclude_masks = []
    set_include_masks = []
    nodes_around = [node for node in local_region if node not in coalition]
    num_nodes_around = len(nodes_around)

    # Enumerate every subset of the surrounding nodes as the "present" set.
    for subset_len in range(0, num_nodes_around + 1):
        node_exclude_subsets = combinations(nodes_around, subset_len)
        for node_exclude_subset in node_exclude_subsets:
            set_exclude_mask = np.ones(num_nodes)
            set_exclude_mask[local_region] = 0.0
            if node_exclude_subset:
                set_exclude_mask[list(node_exclude_subset)] = 1.0
            set_include_mask = set_exclude_mask.copy()
            set_include_mask[coalition] = 1.0

            set_exclude_masks.append(set_exclude_mask)
            set_include_masks.append(set_include_mask)

    exclude_mask = np.stack(set_exclude_masks, axis=0)
    include_mask = np.stack(set_include_masks, axis=0)
    # Shapley weighting 1 / (C(p, S) * (p - S)); the epsilon guards S == p.
    num_players = len(nodes_around) + 1
    num_player_in_set = num_players - 1 + len(coalition) - (1 - exclude_mask).sum(axis=1)
    p = num_players
    S = num_player_in_set
    coeffs = torch.tensor(1.0 / comb(p, S) / (p - S + 1e-6))

    marginal_contributions = \
        marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func)

    l_shapley_value = (marginal_contributions.squeeze().cpu() * coeffs).sum().item()
    return l_shapley_value
145
+
146
+
147
def mc_shapley(coalition: list, data: "Data",
               value_func, subgraph_building_method='zero_filling',
               sample_num=1000) -> float:
    """Monte Carlo sampling approximation of the Shapley value.

    Args:
        coalition: node indices whose joint contribution is scored.
        data: the full input graph.
        value_func: callable mapping a batched graph to per-graph scores
            (the original ``str`` annotation was incorrect).
        subgraph_building_method: 'zero_filling' or 'split'.
        sample_num: number of random permutations to sample.

    Returns:
        float: the estimated Shapley value of the coalition.
    """
    subset_build_func = get_graph_build_func(subgraph_building_method)

    num_nodes = data.num_nodes
    node_indices = np.arange(num_nodes)
    # A placeholder index marks where the whole coalition sits in a permutation.
    coalition_placeholder = num_nodes
    # Invariant across samples — hoisted out of the loop (was rebuilt each
    # of the ``sample_num`` iterations).
    subset_nodes_from = [node for node in node_indices if node not in coalition]
    set_exclude_masks = []
    set_include_masks = []

    for _ in range(sample_num):
        random_nodes_permutation = np.random.permutation(
            np.array(subset_nodes_from + [coalition_placeholder]))
        split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0]
        # Nodes drawn before the coalition placeholder are "already present".
        selected_nodes = random_nodes_permutation[:split_idx]
        set_exclude_mask = np.zeros(num_nodes)
        set_exclude_mask[selected_nodes] = 1.0
        set_include_mask = set_exclude_mask.copy()
        set_include_mask[coalition] = 1.0

        set_exclude_masks.append(set_exclude_mask)
        set_include_masks.append(set_include_mask)

    exclude_mask = np.stack(set_exclude_masks, axis=0)
    include_mask = np.stack(set_include_masks, axis=0)
    marginal_contributions = marginal_contribution(
        data, exclude_mask, include_mask, value_func, subset_build_func)
    mc_shapley_value = marginal_contributions.mean().item()

    return mc_shapley_value
179
+
180
+
181
def mc_l_shapley(coalition: list, data: "Data", local_radius: int,
                 value_func, subgraph_building_method='zero_filling',
                 sample_num=1000) -> float:
    """Monte Carlo sampling approximation of the l_shapley value.

    Players are the nodes within ``local_radius`` hops of the coalition;
    nodes outside that local region are always masked out.

    Args:
        coalition: node indices whose joint contribution is scored.
        data: the full input graph.
        local_radius: number of hops used to grow the local region.
        value_func: callable mapping a batched graph to per-graph scores
            (the original ``str`` annotation was incorrect).
        subgraph_building_method: 'zero_filling' or 'split'.
        sample_num: number of random permutations to sample.

    Returns:
        float: the estimated local Shapley value of the coalition.
    """
    graph = to_networkx(data)
    num_nodes = graph.number_of_nodes()
    subgraph_build_func = get_graph_build_func(subgraph_building_method)

    # Grow the (local_radius - 1)-hop neighborhood around the coalition.
    local_region = copy.copy(coalition)
    for _ in range(local_radius - 1):
        k_neighborhood = []
        for node in local_region:
            k_neighborhood += list(graph.neighbors(node))
        local_region += k_neighborhood
        local_region = list(set(local_region))

    coalition_placeholder = num_nodes
    # Invariant across samples — hoisted out of the loop (was rebuilt each
    # of the ``sample_num`` iterations).
    subset_nodes_from = [node for node in local_region if node not in coalition]
    set_exclude_masks = []
    set_include_masks = []
    for _ in range(sample_num):
        random_nodes_permutation = np.random.permutation(
            np.array(subset_nodes_from + [coalition_placeholder]))
        split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0]
        selected_nodes = random_nodes_permutation[:split_idx]
        set_exclude_mask = np.ones(num_nodes)
        set_exclude_mask[local_region] = 0.0
        set_exclude_mask[selected_nodes] = 1.0
        set_include_mask = set_exclude_mask.copy()
        set_include_mask[coalition] = 1.0

        set_exclude_masks.append(set_exclude_mask)
        set_include_masks.append(set_include_mask)

    exclude_mask = np.stack(set_exclude_masks, axis=0)
    include_mask = np.stack(set_include_masks, axis=0)
    marginal_contributions = \
        marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func)

    mc_l_shapley_value = (marginal_contributions).mean().item()
    return mc_l_shapley_value
222
+
223
+
224
def gnn_score(coalition: list, data: "Data", value_func,
              subgraph_building_method='zero_filling') -> float:
    """Return the value of the subgraph induced by the selected nodes.

    Builds a node mask for ``coalition``, constructs the subgraph with the
    chosen building method and scores it with ``value_func``.

    Args:
        coalition: node indices forming the subgraph.
        data: the full input graph.
        value_func: callable mapping a batched graph to per-graph scores
            (the original ``str`` annotation was incorrect).
        subgraph_building_method: 'zero_filling' or 'split'.

    Returns:
        float: the score of the predicted class for the graph (or the
        specific node idx baked into ``value_func``). The original
        ``torch.Tensor`` return annotation was wrong — ``.item()`` yields
        a Python float.
    """
    num_nodes = data.num_nodes
    subgraph_build_func = get_graph_build_func(subgraph_building_method)
    mask = torch.zeros(num_nodes).type(torch.float32).to(data.x.device)
    mask[coalition] = 1.0
    ret_x, ret_edge_index = subgraph_build_func(data.x, data.edge_index, mask)
    mask_data = Data(x=ret_x, edge_index=ret_edge_index)
    mask_data = Batch.from_data_list([mask_data])
    score = value_func(mask_data)
    return score.item()
237
+
238
+
239
def NC_mc_l_shapley(coalition: list, data: "Data", local_radius: int,
                    value_func, node_idx: int = -1,
                    subgraph_building_method='zero_filling', sample_num=1000) -> float:
    """Monte Carlo l_shapley where the target node is kept in both subgraphs.

    Node-classification variant of :func:`mc_l_shapley`: when ``node_idx``
    is given, that node is forced into both the exclude and include masks so
    its prediction is always defined.

    Args:
        coalition: node indices whose joint contribution is scored.
        data: the full input graph.
        local_radius: number of hops used to grow the local region.
        value_func: callable mapping a batched graph to per-graph scores
            (the original ``str`` annotation was incorrect).
        node_idx: target node to keep in every subgraph; -1 disables this.
        subgraph_building_method: 'zero_filling' or 'split'.
        sample_num: number of random permutations to sample.

    Returns:
        float: the estimated local Shapley value of the coalition.
    """
    graph = to_networkx(data)
    num_nodes = graph.number_of_nodes()
    subgraph_build_func = get_graph_build_func(subgraph_building_method)

    # Grow the (local_radius - 1)-hop neighborhood around the coalition.
    local_region = copy.copy(coalition)
    for _ in range(local_radius - 1):
        k_neighborhood = []
        for node in local_region:
            k_neighborhood += list(graph.neighbors(node))
        local_region += k_neighborhood
        local_region = list(set(local_region))

    coalition_placeholder = num_nodes
    # Invariant across samples — hoisted out of the loop (was rebuilt each
    # of the ``sample_num`` iterations).
    subset_nodes_from = [node for node in local_region if node not in coalition]
    set_exclude_masks = []
    set_include_masks = []
    for _ in range(sample_num):
        random_nodes_permutation = np.random.permutation(
            np.array(subset_nodes_from + [coalition_placeholder]))
        split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0]
        selected_nodes = random_nodes_permutation[:split_idx]
        set_exclude_mask = np.ones(num_nodes)
        set_exclude_mask[local_region] = 0.0
        set_exclude_mask[selected_nodes] = 1.0
        if node_idx != -1:
            # Keep the target node present in both subgraphs.
            set_exclude_mask[node_idx] = 1.0
        set_include_mask = set_exclude_mask.copy()
        set_include_mask[coalition] = 1.0

        set_exclude_masks.append(set_exclude_mask)
        set_include_masks.append(set_include_mask)

    exclude_mask = np.stack(set_exclude_masks, axis=0)
    include_mask = np.stack(set_include_masks, axis=0)
    marginal_contributions = \
        marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func)

    mc_l_shapley_value = (marginal_contributions).mean().item()
    return mc_l_shapley_value
282
+
283
+
284
def sparsity(coalition: list, data: "Data", subgraph_building_method='zero_filling'):
    """Fraction of the graph NOT covered by the coalition.

    For 'zero_filling' this is the fraction of nodes outside the coalition;
    for 'split' it is the fraction of edges with at least one endpoint
    outside the coalition (returned as a 0-dim tensor, matching the
    original behavior).

    Raises:
        NotImplementedError: for an unknown building method (the original
        silently fell through and returned None).
    """
    if subgraph_building_method == 'zero_filling':
        return 1.0 - len(coalition) / data.num_nodes

    elif subgraph_building_method == 'split':
        row, col = data.edge_index
        node_mask = torch.zeros(data.x.shape[0])
        node_mask[coalition] = 1.0
        # An edge is "kept" only when both endpoints are in the coalition.
        edge_mask = (node_mask[row] == 1) & (node_mask[col] == 1)
        return 1.0 - edge_mask.sum() / edge_mask.shape[0]

    raise NotImplementedError