hjxdl 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. hdl/__init__.py +0 -0
  2. hdl/_version.py +16 -0
  3. hdl/args/__init__.py +0 -0
  4. hdl/args/loss_args.py +5 -0
  5. hdl/controllers/__init__.py +0 -0
  6. hdl/controllers/al/__init__.py +0 -0
  7. hdl/controllers/al/al.py +0 -0
  8. hdl/controllers/al/dispatcher.py +0 -0
  9. hdl/controllers/al/feedback.py +0 -0
  10. hdl/controllers/explain/__init__.py +0 -0
  11. hdl/controllers/explain/shapley.py +293 -0
  12. hdl/controllers/explain/subgraphx.py +865 -0
  13. hdl/controllers/train/__init__.py +0 -0
  14. hdl/controllers/train/rxn_train.py +219 -0
  15. hdl/controllers/train/train.py +50 -0
  16. hdl/controllers/train/train_ginet.py +316 -0
  17. hdl/controllers/train/trainer_base.py +155 -0
  18. hdl/controllers/train/trainer_iterative.py +389 -0
  19. hdl/data/__init__.py +0 -0
  20. hdl/data/dataset/__init__.py +0 -0
  21. hdl/data/dataset/base_dataset.py +98 -0
  22. hdl/data/dataset/fp/__init__.py +0 -0
  23. hdl/data/dataset/fp/fp_dataset.py +122 -0
  24. hdl/data/dataset/graph/__init__.py +0 -0
  25. hdl/data/dataset/graph/chiral.py +62 -0
  26. hdl/data/dataset/graph/gin.py +255 -0
  27. hdl/data/dataset/graph/molnet.py +362 -0
  28. hdl/data/dataset/loaders/__init__.py +0 -0
  29. hdl/data/dataset/loaders/chiral_graph.py +71 -0
  30. hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
  31. hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
  32. hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
  33. hdl/data/dataset/loaders/general.py +23 -0
  34. hdl/data/dataset/loaders/spliter.py +86 -0
  35. hdl/data/dataset/samplers/__init__.py +0 -0
  36. hdl/data/dataset/samplers/chiral.py +19 -0
  37. hdl/data/dataset/seq/__init__.py +0 -0
  38. hdl/data/dataset/seq/rxn_dataset.py +61 -0
  39. hdl/data/dataset/utils.py +31 -0
  40. hdl/data/to_mols.py +0 -0
  41. hdl/features/__init__.py +0 -0
  42. hdl/features/fp/__init__.py +0 -0
  43. hdl/features/fp/features_generators.py +235 -0
  44. hdl/features/graph/__init__.py +0 -0
  45. hdl/features/graph/featurization.py +297 -0
  46. hdl/features/utils/__init__.py +0 -0
  47. hdl/features/utils/utils.py +111 -0
  48. hdl/layers/__init__.py +0 -0
  49. hdl/layers/general/__init__.py +0 -0
  50. hdl/layers/general/gp.py +14 -0
  51. hdl/layers/general/linear.py +641 -0
  52. hdl/layers/graph/__init__.py +0 -0
  53. hdl/layers/graph/chiral_graph.py +230 -0
  54. hdl/layers/graph/gcn.py +16 -0
  55. hdl/layers/graph/gin.py +45 -0
  56. hdl/layers/graph/tetra.py +158 -0
  57. hdl/layers/graph/transformer.py +188 -0
  58. hdl/layers/sequential/__init__.py +0 -0
  59. hdl/metric_loss/__init__.py +0 -0
  60. hdl/metric_loss/loss.py +79 -0
  61. hdl/metric_loss/metric.py +178 -0
  62. hdl/metric_loss/multi_label.py +42 -0
  63. hdl/metric_loss/nt_xent.py +65 -0
  64. hdl/models/__init__.py +0 -0
  65. hdl/models/chiral_gnn.py +176 -0
  66. hdl/models/fast_transformer.py +234 -0
  67. hdl/models/ginet.py +189 -0
  68. hdl/models/linear.py +137 -0
  69. hdl/models/model_dict.py +18 -0
  70. hdl/models/norm_flows.py +33 -0
  71. hdl/models/optim_dict.py +16 -0
  72. hdl/models/rxn.py +63 -0
  73. hdl/models/utils.py +83 -0
  74. hdl/ops/__init__.py +0 -0
  75. hdl/ops/utils.py +42 -0
  76. hdl/optims/__init__.py +0 -0
  77. hdl/optims/nadam.py +86 -0
  78. hdl/utils/__init__.py +0 -0
  79. hdl/utils/chemical_tools/__init__.py +2 -0
  80. hdl/utils/chemical_tools/query_info.py +149 -0
  81. hdl/utils/chemical_tools/sdf.py +20 -0
  82. hdl/utils/database_tools/__init__.py +0 -0
  83. hdl/utils/database_tools/connect.py +28 -0
  84. hdl/utils/general/__init__.py +0 -0
  85. hdl/utils/general/glob.py +21 -0
  86. hdl/utils/schedulers/__init__.py +0 -0
  87. hdl/utils/schedulers/norm_lr.py +108 -0
  88. hjxdl-0.0.1.dist-info/METADATA +19 -0
  89. hjxdl-0.0.1.dist-info/RECORD +91 -0
  90. hjxdl-0.0.1.dist-info/WHEEL +5 -0
  91. hjxdl-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,865 @@
1
+ import os
2
+ import math
3
+ import copy
4
+ from typing import List, Tuple, Dict
5
+ from typing import Callable, Union, Optional
6
+ from collections import Counter
7
+ from functools import partial
8
+ from textwrap import wrap
9
+
10
+ import torch
11
+ import numpy as np
12
+ import networkx as nx
13
+ from rdkit import Chem
14
+ from torch import Tensor
15
+ from torch_geometric.data import Batch, Data
16
+ from torch_geometric.utils import to_networkx
17
+ import matplotlib.pyplot as plt
18
+ from torch_geometric.utils.num_nodes import maybe_num_nodes
19
+ from torch_geometric.nn.conv import MessagePassing
20
+ from torch_geometric.datasets import MoleculeNet
21
+ from torch_geometric.utils import remove_self_loops
22
+
23
+ from .shapley import GnnNetsGC2valueFunc, GnnNetsNC2valueFunc, \
24
+ gnn_score, mc_shapley, l_shapley, mc_l_shapley, NC_mc_l_shapley, sparsity
25
+
26
+
27
def find_closest_node_result(results, max_nodes):
    """Return the highest-reward tree node whose coalition fits in ``max_nodes``.

    The candidates are ordered by coalition size and the smallest one is taken
    as the initial choice.  It is then replaced by any candidate that has at
    most ``max_nodes`` nodes and a strictly larger ``P`` than the current
    choice.  When no candidate fits, the smallest coalition is returned even
    though it is larger than ``max_nodes``.

    Args:
        results: non-empty iterable of objects exposing ``coalition`` (sized)
            and ``P`` (comparable reward).
        max_nodes: maximum coalition size allowed to replace the initial choice.

    Returns:
        The selected element of ``results``.

    Raises:
        IndexError: if ``results`` is empty.
    """
    ordered = sorted(results, key=lambda node: len(node.coalition))

    best = ordered[0]
    for candidate in ordered:
        if len(candidate.coalition) <= max_nodes and candidate.P > best.P:
            best = candidate
    return best
37
+
38
+
39
def reward_func(reward_method, value_func, node_idx=None,
                local_radius=4, sample_num=100,
                subgraph_building_method='zero_filling'):
    """Build the reward callable selected by ``reward_method``.

    The returned object is a :func:`functools.partial` of one of the scoring
    functions imported from ``.shapley`` with the relevant keyword arguments
    already bound.

    Args:
        reward_method (str): one of ``'gnn_score'``, ``'mc_shapley'``,
            ``'l_shapley'``, ``'mc_l_shapley'`` or ``'nc_mc_l_shapley'``
            (compared case-insensitively).
        value_func: forwarded as ``value_func`` to the scoring function.
        node_idx: required (not ``None``) for ``'nc_mc_l_shapley'``; unused
            by the other methods.
        local_radius: forwarded to the ``l_shapley`` variants.
        sample_num: forwarded to the Monte-Carlo variants.
        subgraph_building_method: forwarded to every scoring function.

    Returns:
        functools.partial: the configured scoring function.

    Raises:
        NotImplementedError: if ``reward_method`` is not a known method.
        AssertionError: if ``'nc_mc_l_shapley'`` is requested without
            ``node_idx``.
    """
    # Normalise once instead of lower-casing in every branch.
    method = reward_method.lower()

    if method == 'gnn_score':
        return partial(gnn_score,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method)

    if method == 'mc_shapley':
        return partial(mc_shapley,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method,
                       sample_num=sample_num)

    if method == 'l_shapley':
        return partial(l_shapley,
                       local_radius=local_radius,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method)

    if method == 'mc_l_shapley':
        return partial(mc_l_shapley,
                       local_radius=local_radius,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method,
                       sample_num=sample_num)

    if method == 'nc_mc_l_shapley':
        assert node_idx is not None, " Wrong node idx input "
        return partial(NC_mc_l_shapley,
                       node_idx=node_idx,
                       local_radius=local_radius,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method,
                       sample_num=sample_num)

    raise NotImplementedError(f"Unknown reward_method: {reward_method!r}")
77
+
78
+
79
def k_hop_subgraph_with_default_whole_graph(
        edge_index, node_idx=None, num_hops=3, relabel_nodes=False,
        num_nodes=None, flow='source_to_target'):
    r"""Computes the :math:`k`-hop subgraph of :obj:`edge_index` around node
    :attr:`node_idx`.

    When :attr:`node_idx` is :obj:`None`, :attr:`num_hops` is ignored and the
    subgraph is grown from node ``0`` until no new node is reached.

    It returns (1) the nodes involved in the subgraph, (2) the filtered
    :obj:`edge_index` connectivity, (3) the mapping from node indices in
    :obj:`node_idx` to their new location (:obj:`None` when :attr:`node_idx`
    is :obj:`None`), and (4) the edge mask indicating which edges were
    preserved.

    Args:
        edge_index (LongTensor): The edge indices.
        node_idx (int, list, tuple or :obj:`torch.Tensor`, optional): The
            central node(s). (default: :obj:`None`)
        num_hops: (int): The number of hops :math:`k`.
        relabel_nodes (bool, optional): If set to :obj:`True`, the resulting
            :obj:`edge_index` will be relabeled to hold consecutive indices
            starting from zero. (default: :obj:`False`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        flow (string, optional): The flow direction of :math:`k`-hop
            aggregation (:obj:`"source_to_target"` or
            :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)

    :rtype: (:class:`LongTensor`, :class:`LongTensor`, :class:`LongTensor`,
             :class:`BoolTensor`)
    """
    # maybe_num_nodes returns a supplied num_nodes unchanged, so it only needs
    # to be consulted when the caller did not provide one.
    if num_nodes is None:
        num_nodes = maybe_num_nodes(edge_index, num_nodes)

    assert flow in ['source_to_target', 'target_to_source']
    if flow == 'target_to_source':
        row, col = edge_index
    else:
        col, row = edge_index  # edge_index 0 to 1, col: source, row: target

    node_mask = row.new_empty(num_nodes, dtype=torch.bool)
    inv = None

    if node_idx is None:
        # Seed on the same device as the edges; a CPU seed cannot be
        # concatenated with col[...] when edge_index lives elsewhere.
        reached = torch.tensor([0], device=row.device)
        while True:
            node_mask.fill_(False)
            node_mask[reached] = True
            grown = torch.cat([reached, col[node_mask[row]]]).unique()
            if grown.equal(reached):
                break
            reached = grown
        subset = reached
    else:
        if isinstance(node_idx, (int, list, tuple)):
            node_idx = torch.tensor([node_idx], device=row.device, dtype=torch.int64).flatten()
        elif isinstance(node_idx, torch.Tensor) and node_idx.dim() == 0:
            # Promote a scalar tensor to shape (1,) and keep it on the edge device.
            node_idx = node_idx.reshape(1).to(row.device)
        else:
            node_idx = node_idx.to(row.device)

        subsets = [node_idx]
        for _ in range(num_hops):
            node_mask.fill_(False)
            node_mask[subsets[-1]] = True
            # Edges whose `row` endpoint is in the current frontier contribute
            # their `col` endpoint to the next frontier.
            subsets.append(col[node_mask[row]])
        subset, inv = torch.cat(subsets).unique(return_inverse=True)
        inv = inv[:node_idx.numel()]

    # Keep only the edges with both endpoints inside the subset.
    node_mask.fill_(False)
    node_mask[subset] = True
    edge_mask = node_mask[row] & node_mask[col]

    edge_index = edge_index[:, edge_mask]

    if relabel_nodes:
        relabel = row.new_full((num_nodes,), -1)
        relabel[subset] = torch.arange(subset.size(0), device=row.device)
        edge_index = relabel[edge_index]

    return subset, edge_index, inv, edge_mask  # subset: key new node idx; value original node idx
160
+
161
+
162
def compute_scores(score_func, children):
    """Return one score per child, in order.

    A child whose ``P`` equals ``0`` is scored with
    ``score_func(child.coalition, child.data)``; any other child keeps its
    existing ``P``.  The children themselves are not modified.
    """
    return [
        child.P if child.P != 0 else score_func(child.coalition, child.data)
        for child in children
    ]
171
+
172
+
173
class PlotUtils(object):
    """Draw a graph with one coalition (subset of its nodes) highlighted.

    The drawing routine is selected from ``dataset_name``; see :meth:`plot`.

    Args:
        dataset_name (str): dataset name used by :meth:`plot` to dispatch.
        is_show (bool): call ``plt.show()`` after drawing. (default: True)
    """

    def __init__(self, dataset_name, is_show=True):
        self.dataset_name = dataset_name
        self.is_show = is_show

    @staticmethod
    def _coalition_edges(graph, nodelist):
        """Return the edges of ``graph`` whose two endpoints are both in ``nodelist``."""
        return [(n_frm, n_to) for (n_frm, n_to) in graph.edges()
                if n_frm in nodelist and n_to in nodelist]

    def _finish_figure(self, title, figname):
        """Hide the axes, set the title, then save / show / close the figure.

        The title is applied before saving so that it is part of the saved
        image.  The figure is closed only when it was saved to ``figname``.
        """
        plt.axis('off')
        if title is not None:
            plt.title(title)
        if figname is not None:
            plt.savefig(figname)
        if self.is_show:
            plt.show()
        if figname is not None:
            plt.close()

    def plot(self, graph, nodelist, figname, title_sentence=None, **kwargs):
        """Dispatch to the drawing routine matching ``self.dataset_name``.

        Extra inputs are read from ``kwargs``: ``x`` for molecule datasets,
        ``y`` and ``node_idx`` for the BA-shapes family, ``words`` for the
        sentence datasets.

        Raises:
            NotImplementedError: if the dataset name is not recognised.
        """
        name = self.dataset_name.lower()
        if name in ['ba_2motifs', 'ba_lrp']:
            self.plot_ba2motifs(graph, nodelist, title_sentence=title_sentence, figname=figname)
        elif name in ['mutag'] + list(MoleculeNet.names.keys()):
            x = kwargs.get('x')
            self.plot_molecule(graph, nodelist, x, title_sentence=title_sentence, figname=figname)
        elif name in ['ba_shapes', 'ba_community', 'tree_grid', 'tree_cycle']:
            y = kwargs.get('y')
            node_idx = kwargs.get('node_idx')
            self.plot_bashapes(graph, nodelist, y, node_idx, title_sentence=title_sentence, figname=figname)
        elif name in ['graph_sst2', 'graph_sst5', 'twitter']:
            words = kwargs.get('words')
            self.plot_sentence(graph, nodelist, words=words, title_sentence=title_sentence, figname=figname)
        else:
            raise NotImplementedError

    def plot_subgraph(self,
                      graph,
                      nodelist,
                      colors: Union[None, str, List[str]] = '#FFA500',
                      labels=None,
                      edge_color='gray',
                      edgelist=None,
                      subgraph_edge_color='black',
                      title_sentence=None,
                      figname=None):
        """Draw ``graph`` and overdraw the coalition's edges with a thicker line.

        When ``edgelist`` is ``None`` the highlighted edges are those with both
        endpoints in ``nodelist``.
        """
        if edgelist is None:
            edgelist = self._coalition_edges(graph, nodelist)
        pos = nx.kamada_kawai_layout(graph)
        pos_nodelist = {k: v for k, v in pos.items() if k in nodelist}

        nx.draw_networkx_nodes(graph, pos,
                               nodelist=list(graph.nodes()),
                               node_color=colors,
                               node_size=300)

        nx.draw_networkx_edges(graph, pos, width=3, edge_color=edge_color, arrows=False)

        nx.draw_networkx_edges(graph, pos=pos_nodelist,
                               edgelist=edgelist, width=6,
                               edge_color=subgraph_edge_color,
                               arrows=False)

        if labels is not None:
            nx.draw_networkx_labels(graph, pos, labels)

        title = None
        if title_sentence is not None:
            title = '\n'.join(wrap(title_sentence, width=60))
        # Title first, then save: previously the file was written before the
        # title was set, so saved figures had no title.
        self._finish_figure(title, figname)

    def plot_subgraph_with_nodes(self,
                                 graph,
                                 nodelist,
                                 node_idx,
                                 colors='#FFA500',
                                 labels=None,
                                 edge_color='gray',
                                 edgelist=None,
                                 subgraph_edge_color='black',
                                 title_sentence=None,
                                 figname=None):
        """Like :meth:`plot_subgraph`, additionally drawing ``node_idx`` enlarged."""
        node_idx = int(node_idx)
        if edgelist is None:
            edgelist = self._coalition_edges(graph, nodelist)

        pos = nx.kamada_kawai_layout(graph)  # calculate according to graph.nodes()
        pos_nodelist = {k: v for k, v in pos.items() if k in nodelist}

        nx.draw_networkx_nodes(graph, pos,
                               nodelist=list(graph.nodes()),
                               node_color=colors,
                               node_size=300)
        if isinstance(colors, list):
            # Position of the target node in drawing order; a plain list lookup
            # avoids converting a 1-element array to a scalar.
            node_idx_color = colors[list(graph.nodes()).index(node_idx)]
        else:
            node_idx_color = colors

        nx.draw_networkx_nodes(graph, pos=pos,
                               nodelist=[node_idx],
                               node_color=node_idx_color,
                               node_size=600)

        nx.draw_networkx_edges(graph, pos, width=3, edge_color=edge_color, arrows=False)

        nx.draw_networkx_edges(graph, pos=pos_nodelist,
                               edgelist=edgelist, width=3,
                               edge_color=subgraph_edge_color,
                               arrows=False)

        if labels is not None:
            nx.draw_networkx_labels(graph, pos, labels)

        title = None
        if title_sentence is not None:
            title = '\n'.join(wrap(title_sentence, width=60))
        self._finish_figure(title, figname)

    def plot_sentence(self, graph, nodelist, words, edgelist=None, title_sentence=None, figname=None):
        """Draw a word graph, labelling node ``i`` with ``words[i]``.

        The coalition is highlighted in yellow only when ``nodelist`` is given;
        its edges are drawn only when ``edgelist`` is ``None`` (they are then
        derived from ``nodelist``).
        """
        pos = nx.kamada_kawai_layout(graph)
        words_dict = {i: words[i] for i in graph.nodes}
        if nodelist is not None:
            pos_coalition = {k: v for k, v in pos.items() if k in nodelist}
            nx.draw_networkx_nodes(graph, pos_coalition,
                                   nodelist=nodelist,
                                   node_color='yellow',
                                   node_shape='o',
                                   node_size=500)
            if edgelist is None:
                edgelist = self._coalition_edges(graph, nodelist)
                nx.draw_networkx_edges(graph, pos=pos_coalition, edgelist=edgelist, width=5, edge_color='yellow')

        nx.draw_networkx_nodes(graph, pos, nodelist=list(graph.nodes()), node_size=300)

        nx.draw_networkx_edges(graph, pos, width=2, edge_color='grey')
        nx.draw_networkx_labels(graph, pos, words_dict)

        # The sentence itself is always the title; title_sentence is appended.
        title = '\n'.join(wrap(' '.join(words), width=50))
        if title_sentence is not None:
            title += '\n'.join(wrap(title_sentence, width=60))
        self._finish_figure(title, figname)

    def plot_ba2motifs(self,
                       graph,
                       nodelist,
                       edgelist=None,
                       title_sentence=None,
                       figname=None):
        """Draw with the default colours of :meth:`plot_subgraph`."""
        return self.plot_subgraph(graph, nodelist,
                                  edgelist=edgelist,
                                  title_sentence=title_sentence,
                                  figname=figname)

    def plot_molecule(self,
                      graph,
                      nodelist,
                      x,
                      edgelist=None,
                      title_sentence=None,
                      figname=None):
        """Draw a molecule graph with element symbols and per-element colours.

        For ``mutag`` the element of each node is the column index at which
        ``x`` equals 1 (presumably one-hot rows; confirm against the dataset).
        For MoleculeNet datasets it is read from ``x[:, 0]`` and turned into a
        symbol through RDKit's periodic table.

        Raises:
            NotImplementedError: for any other dataset name.
        """
        # Compare the lower-cased name, consistent with the dispatch in plot().
        name = self.dataset_name.lower()

        # collect the text information and node color
        if name == 'mutag':
            node_dict = {0: 'C', 1: 'N', 2: 'O', 3: 'F', 4: 'I', 5: 'Cl', 6: 'Br'}
            node_idxs = {k: int(v) for k, v in enumerate(np.where(x.cpu().numpy() == 1)[1])}
            node_labels = {k: node_dict[v] for k, v in node_idxs.items()}
            node_color = ['#E49D1C', '#4970C6', '#FF5357', '#29A329', 'brown', 'darkslategray', '#F0EA00']
            colors = [node_color[v % len(node_color)] for v in node_idxs.values()]

        elif name in MoleculeNet.names.keys():
            node_idxs = {k: int(v) for k, v in enumerate(x[:, 0])}
            periodic_table = Chem.GetPeriodicTable()
            node_labels = {k: periodic_table.GetElementSymbol(int(v))
                           for k, v in node_idxs.items()}
            node_color = ['#29A329', 'lime', '#F0EA00', 'maroon', 'brown', '#E49D1C', '#4970C6', '#FF5357']
            colors = [node_color[(v - 1) % len(node_color)] for v in node_idxs.values()]
        else:
            raise NotImplementedError

        self.plot_subgraph(graph, nodelist,
                           colors=colors,
                           labels=node_labels,
                           edgelist=edgelist,
                           edge_color='gray',
                           subgraph_edge_color='black',
                           title_sentence=title_sentence,
                           figname=figname)

    def plot_bashapes(self,
                      graph,
                      nodelist,
                      y,
                      node_idx,
                      edgelist=None,
                      title_sentence=None,
                      figname=None):
        """Draw a node-level explanation, colouring each node by its value in ``y``."""
        node_idxs = {k: int(v) for k, v in enumerate(y.reshape(-1).tolist())}
        node_color = ['#FFA500', '#4970C6', '#FE0000', 'green']
        colors = [node_color[v % len(node_color)] for v in node_idxs.values()]
        self.plot_subgraph_with_nodes(graph,
                                      nodelist,
                                      node_idx,
                                      colors,
                                      edgelist=edgelist,
                                      title_sentence=title_sentence,
                                      figname=figname,
                                      subgraph_edge_color='black')
389
+
390
+
391
class MCTSNode(object):
    """A single state of the search tree: a coalition of nodes plus statistics.

    Attributes are documented inline in ``__init__``.  When ``load_dict`` is
    given, the statistics, coalition, graph and data are taken from it (see
    :meth:`load_info`) instead of from the individual arguments.
    """

    def __init__(self, coalition: Optional[list] = None, data: Optional["Data"] = None,
                 ori_graph: Optional["nx.Graph"] = None,
                 c_puct: float = 10.0, W: float = 0, N: int = 0, P: float = 0,
                 load_dict: Optional[Dict] = None, device='cpu'):
        self.device = device
        self.c_puct = c_puct
        self.coalition = coalition
        self.data = data
        self.ori_graph = ori_graph
        self.children = []
        self.W = W  # sum of node value
        self.N = N  # times of arrival
        self.P = P  # property score (reward)
        if load_dict is not None:
            self.load_info(load_dict)

    def Q(self):
        """Mean value ``W / N``, or ``0`` while the node has not been visited."""
        if self.N > 0:
            return self.W / self.N
        return 0

    def U(self, n):
        """Exploration term ``c_puct * P * sqrt(n) / (1 + N)``."""
        return self.c_puct * self.P * math.sqrt(n) / (1 + self.N)

    @property
    def info(self):
        """Dictionary snapshot of this node; ``data`` is passed through ``.to('cpu')``."""
        return {
            'data': self.data.to('cpu'),
            'coalition': self.coalition,
            'ori_graph': self.ori_graph,
            'W': self.W,
            'N': self.N,
            'P': self.P,
        }

    def load_info(self, info_dict):
        """Restore the node from a dictionary shaped like :attr:`info`.

        ``data`` is moved with ``.to(self.device)`` and the children list is
        reset.  Returns ``self`` so the call can be chained.
        """
        self.coalition = info_dict['coalition']
        self.ori_graph = info_dict['ori_graph']
        self.data = info_dict['data'].to(self.device)
        self.W = info_dict['W']
        self.N = info_dict['N']
        self.P = info_dict['P']
        self.children = []
        return self
434
+
435
+
436
class MCTS(object):
    r"""
    Monte Carlo Tree Search Method

    Each tree node is a coalition (sorted list of node indices).  Children are
    produced by pruning one node from the parent coalition and keeping one
    connected component of what remains.

    Args:
        X (:obj:`torch.Tensor`): Input node features
        edge_index (:obj:`torch.Tensor`): The edge indices.
        num_hops (:obj:`int`): The number of hops :math:`k`.
        n_rollout (:obj:`int`): The number of sequence to build the monte carlo tree.
        min_atoms (:obj:`int`): The number of atoms for the subgraph in the monte carlo tree leaf node.
        c_puct (:obj:`float`): The hyper-parameter to encourage exploration while searching.
        expand_atoms (:obj:`int`): The number of children to expand.
        high2low (:obj:`bool`): Whether to expand children tree node from high degree nodes to low degree nodes.
        node_idx (:obj:`int`): The target node index to extract the neighborhood.
        score_func (:obj:`Callable`): The reward function for tree node, such as mc_shapely and mc_l_shapely.
        device: Device handed to every created :class:`MCTSNode`.
    """

    def __init__(self, X: torch.Tensor, edge_index: torch.Tensor, num_hops: int,
                 n_rollout: int = 10, min_atoms: int = 3, c_puct: float = 10.0,
                 expand_atoms: int = 14, high2low: bool = False,
                 node_idx: Optional[int] = None, score_func: Optional[Callable] = None,
                 device='cpu'):

        self.X = X
        self.edge_index = edge_index
        self.device = device
        self.num_hops = num_hops
        self.data = Data(x=self.X, edge_index=self.edge_index)
        graph_data = Data(x=self.X, edge_index=remove_self_loops(self.edge_index)[0])
        self.graph = to_networkx(graph_data, to_undirected=True)
        self.data = Batch.from_data_list([self.data])
        self.num_nodes = self.graph.number_of_nodes()
        self.score_func = score_func
        self.n_rollout = n_rollout
        self.min_atoms = min_atoms
        self.c_puct = c_puct
        self.expand_atoms = expand_atoms
        self.high2low = high2low
        self.new_node_idx = None

        # extract the sub-graph and change the node indices.
        if node_idx is not None:
            self.ori_node_idx = node_idx
            self.ori_graph = copy.copy(self.graph)
            x, edge_index, subset, edge_mask, kwargs = \
                self.__subgraph__(node_idx, self.X, self.edge_index, self.num_hops)
            self.data = Batch.from_data_list([Data(x=x, edge_index=edge_index)])
            self.graph = self.ori_graph.subgraph(subset.tolist())
            mapping = {int(v): k for k, v in enumerate(subset)}
            self.graph = nx.relabel_nodes(self.graph, mapping)
            self.new_node_idx = torch.where(subset == self.ori_node_idx)[0].item()
            self.num_nodes = self.graph.number_of_nodes()
            self.subset = subset

        self.root_coalition = sorted([node for node in range(self.num_nodes)])
        self.MCTSNodeClass = partial(MCTSNode, data=self.data, ori_graph=self.graph,
                                     c_puct=self.c_puct, device=self.device)
        self.root = self.MCTSNodeClass(self.root_coalition)
        # Keyed by str(sorted coalition); every coalition stored here is sorted.
        self.state_map = {str(self.root.coalition): self.root}

    def set_score_func(self, score_func):
        """Replace the reward function used to score newly expanded children."""
        self.score_func = score_func

    @staticmethod
    def __subgraph__(node_idx, x, edge_index, num_hops, **kwargs):
        """Extract the ``num_hops`` neighbourhood of ``node_idx`` with relabelled nodes.

        Tensors in ``kwargs`` whose first dimension equals the number of nodes
        (or edges) are restricted to the kept nodes (or edges).

        Returns:
            ``(x, edge_index, subset, edge_mask, kwargs)``.
        """
        num_nodes, num_edges = x.size(0), edge_index.size(1)
        subset, edge_index, _, edge_mask = k_hop_subgraph_with_default_whole_graph(
            edge_index, node_idx, num_hops, relabel_nodes=True, num_nodes=num_nodes)

        x = x[subset]
        for key, item in kwargs.items():
            if torch.is_tensor(item) and item.size(0) == num_nodes:
                item = item[subset]
            elif torch.is_tensor(item) and item.size(0) == num_edges:
                item = item[edge_mask]
            kwargs[key] = item

        return x, edge_index, subset, edge_mask, kwargs

    def _expand(self, tree_node):
        """Create the children of ``tree_node`` by pruning one node at a time."""
        node_degree_list = sorted(self.graph.subgraph(tree_node.coalition).degree,
                                  key=lambda pair: pair[1], reverse=self.high2low)
        all_nodes = [node for node, _ in node_degree_list]

        # `is not None` rather than truthiness: the target node may have been
        # relabelled to index 0 and must still be protected from pruning.
        has_target = self.new_node_idx is not None
        if has_target:
            expand_nodes = [node for node in all_nodes if node != self.new_node_idx]
        else:
            expand_nodes = all_nodes

        for each_node in expand_nodes[:self.expand_atoms]:
            # for each node, pruning it and get the remaining sub-graph
            subgraph_coalition = [node for node in all_nodes if node != each_node]
            subgraphs = [self.graph.subgraph(c)
                         for c in nx.connected_components(self.graph.subgraph(subgraph_coalition))]
            if not subgraphs:
                # Nothing is left after pruning; no child to create.
                continue

            main_sub = None
            if has_target:
                # Keep the component that still contains the target node.
                for sub in subgraphs:
                    if sub.has_node(self.new_node_idx):
                        main_sub = sub
                        break
            if main_sub is None:
                # Otherwise keep the largest component (first one on ties).
                main_sub = max(subgraphs, key=lambda sub: sub.number_of_nodes())

            new_graph_coalition = sorted(main_sub.nodes())

            # check the state map and merge the same sub-graph
            state_key = str(new_graph_coalition)
            new_node = self.state_map.get(state_key)
            if new_node is None:
                new_node = self.MCTSNodeClass(new_graph_coalition)
                self.state_map[state_key] = new_node

            if not any(sorted(child.coalition) == new_graph_coalition
                       for child in tree_node.children):
                tree_node.children.append(new_node)

    def mcts_rollout(self, tree_node):
        """Run one rollout below ``tree_node`` and return the leaf reward.

        A node with at most ``min_atoms`` nodes is a leaf and returns its own
        ``P``.  Otherwise the node is expanded on first visit, the child
        maximising ``Q() + U(total child visits)`` is followed recursively, and
        that child's ``W`` / ``N`` are updated with the returned value.
        """
        if len(tree_node.coalition) <= self.min_atoms:
            return tree_node.P

        # Expand if this node has never been visited
        if len(tree_node.children) == 0:
            self._expand(tree_node)
            scores = compute_scores(self.score_func, tree_node.children)
            for child, score in zip(tree_node.children, scores):
                child.P = score

        if not tree_node.children:
            # No child could be created; treat this node as a leaf.
            return tree_node.P

        sum_count = sum(c.N for c in tree_node.children)
        selected_node = max(tree_node.children, key=lambda x: x.Q() + x.U(sum_count))
        v = self.mcts_rollout(selected_node)
        selected_node.W += v
        selected_node.N += 1
        return v

    def mcts(self, verbose=True):
        """Run ``n_rollout`` rollouts from the root.

        Returns:
            list: every explored tree node, sorted by ``P`` in descending order.
        """
        if verbose:
            print(f"The nodes in graph is {self.graph.number_of_nodes()}")
        for rollout_idx in range(self.n_rollout):
            self.mcts_rollout(self.root)
            if verbose:
                print(f"At the {rollout_idx} rollout, {len(self.state_map)} states that have been explored.")

        explanations = list(self.state_map.values())
        explanations = sorted(explanations, key=lambda x: x.P, reverse=True)
        return explanations
593
+
594
+
595
+ class SubgraphX(object):
596
+ r"""
597
+ The implementation of paper
598
+ `On Explainability of Graph Neural Networks via Subgraph Explorations <https://arxiv.org/abs/2102.05152>`_.
599
+ Args:
600
+ model (:obj:`torch.nn.Module`): The target model prepared to explain
601
+ num_classes(:obj:`int`): Number of classes for the datasets
602
+ num_hops(:obj:`int`, :obj:`None`): The number of hops to extract neighborhood of target node
603
+ (default: :obj:`None`)
604
+ explain_graph(:obj:`bool`): Whether to explain graph classification model (default: :obj:`True`)
605
+ rollout(:obj:`int`): Number of iteration to get the prediction
606
+ min_atoms(:obj:`int`): Number of atoms of the leaf node in search tree
607
+ c_puct(:obj:`float`): The hyperparameter which encourages the exploration
608
+ expand_atoms(:obj:`int`): The number of atoms to expand
609
+ when extend the child nodes in the search tree
610
+ high2low(:obj:`bool`): Whether to expand children nodes from high degree to low degree when
611
+ extend the child nodes in the search tree (default: :obj:`False`)
612
+ local_radius(:obj:`int`): Number of local radius to calculate :obj:`l_shapley`, :obj:`mc_l_shapley`
613
+ sample_num(:obj:`int`): Sampling time of monte carlo sampling approximation for
614
+ :obj:`mc_shapley`, :obj:`mc_l_shapley` (default: :obj:`mc_l_shapley`)
615
+ reward_method(:obj:`str`): The command string to select the
616
+ subgraph_building_method(:obj:`str`): The command string for different subgraph building method,
617
+ such as :obj:`zero_filling`, :obj:`split` (default: :obj:`zero_filling`)
618
+ save_dir(:obj:`str`, :obj:`None`): Root directory to save the explanation results (default: :obj:`None`)
619
+ filename(:obj:`str`): The filename of results
620
+ vis(:obj:`bool`): Whether to show the visualization (default: :obj:`True`)
621
+ Example:
622
+ >>> # For graph classification task
623
+ >>> subgraphx = SubgraphX(model=model, num_classes=2)
624
+ >>> _, explanation_results, related_preds = subgraphx(x, edge_index)
625
+ """
626
def __init__(self, model, num_classes: int, device, num_hops: Optional[int] = None, verbose: bool = False,
             explain_graph: bool = True, rollout: int = 20, min_atoms: int = 5, c_puct: float = 10.0,
             expand_atoms=14, high2low=False, local_radius=4, sample_num=100, reward_method='mc_l_shapley',
             subgraph_building_method='zero_filling', save_dir: Optional[str] = None,
             filename: str = 'example', vis: bool = True):
    """Store the explainer configuration and prepare the model.

    The model is switched to evaluation mode and moved to ``device``.  When
    ``num_hops`` is ``None`` it is derived from the model by
    ``update_num_hops``.  Saving is enabled exactly when ``save_dir`` is given.
    """
    # model under explanation: eval mode, on the requested device
    self.model = model
    self.model.eval()
    self.device = device
    self.model.to(self.device)

    self.num_classes = num_classes
    self.num_hops = self.update_num_hops(num_hops)
    self.explain_graph = explain_graph
    self.verbose = verbose

    # mcts hyper-parameters
    self.rollout = rollout
    self.min_atoms = min_atoms
    self.c_puct = c_puct
    self.expand_atoms = expand_atoms
    self.high2low = high2low

    # reward function hyper-parameters
    self.local_radius = local_radius
    self.sample_num = sample_num
    self.reward_method = reward_method
    self.subgraph_building_method = subgraph_building_method

    # saving and visualization
    self.vis = vis
    self.save_dir = save_dir
    self.filename = filename
    self.save = self.save_dir is not None
659
+
660
def update_num_hops(self, num_hops):
    """Return ``num_hops`` if given, else the number of ``MessagePassing`` modules in the model."""
    if num_hops is None:
        return sum(1 for module in self.model.modules()
                   if isinstance(module, MessagePassing))
    return num_hops
669
+
670
def get_reward_func(self, value_func, node_idx=None):
    """Build the reward function from the stored reward settings.

    For graph-level explanation ``node_idx`` is ignored (forced to ``None``);
    for node-level explanation it must be provided.
    """
    if not self.explain_graph:
        assert node_idx is not None
    else:
        node_idx = None

    settings = dict(
        reward_method=self.reward_method,
        value_func=value_func,
        node_idx=node_idx,
        local_radius=self.local_radius,
        sample_num=self.sample_num,
        subgraph_building_method=self.subgraph_building_method,
    )
    return reward_func(**settings)
681
+
682
def get_mcts_class(self, x, edge_index, node_idx: int = None, score_func: Callable = None):
    """Create an ``MCTS`` searcher configured with the stored hyper-parameters.

    For graph-level explanation ``node_idx`` is ignored (forced to ``None``);
    for node-level explanation it must be provided.
    """
    if not self.explain_graph:
        assert node_idx is not None
    else:
        node_idx = None

    search_kwargs = dict(
        node_idx=node_idx,
        device=self.device,
        score_func=score_func,
        num_hops=self.num_hops,
        n_rollout=self.rollout,
        min_atoms=self.min_atoms,
        c_puct=self.c_puct,
        expand_atoms=self.expand_atoms,
        high2low=self.high2low,
    )
    return MCTS(x, edge_index, **search_kwargs)
697
+
698
def visualization(self, results: list,
                  max_nodes: int, plot_utils: "PlotUtils", words: Optional[list] = None,
                  y: Optional[Tensor] = None, title_sentence: Optional[str] = None,
                  vis_name: Optional[str] = None):
    """Plot the best explanation of at most ``max_nodes`` nodes found in ``results``.

    The figure is written to ``vis_name`` (default ``"<filename>.png"``) only
    when saving is enabled; otherwise no file name is passed to the plotter.
    Node-level explanation reads ``self.mcts_state_map`` for the node subset
    and the relabelled target node, and indexes ``y`` with that subset.
    """
    if not self.save:
        vis_name = None
    elif vis_name is None:
        vis_name = f"{self.filename}.png"

    tree_node_x = find_closest_node_result(results, max_nodes=max_nodes)
    graph = tree_node_x.ori_graph
    coalition = tree_node_x.coalition

    if not self.explain_graph:
        searcher = self.mcts_state_map
        subgraph_y = y[searcher.subset].to('cpu')
        # Re-order the labels to follow the node order of the plotted graph.
        subgraph_y = torch.tensor([subgraph_y[node].item() for node in graph.nodes()])
        plot_utils.plot(graph,
                        coalition,
                        node_idx=searcher.new_node_idx,
                        title_sentence=title_sentence,
                        y=subgraph_y,
                        figname=vis_name)
    elif words is not None:
        plot_utils.plot(graph,
                        coalition,
                        words=words,
                        title_sentence=title_sentence,
                        figname=vis_name)
    else:
        plot_utils.plot(graph,
                        coalition,
                        x=tree_node_x.data.x,
                        title_sentence=title_sentence,
                        figname=vis_name)
732
+
733
def read_from_MCTSInfo_list(self, MCTSInfo_list):
    """Rebuild ``MCTSNode`` objects from serialized node-info dictionaries.

    Args:
        MCTSInfo_list: Either a flat list of info dicts, or a list of such
            lists (one inner list per label).

    Returns:
        A structure with the same nesting as the input in which every info
        dict is replaced by the value returned from
        ``MCTSNode(device=self.device).load_info(info)``. An empty input
        yields an empty list.

    Raises:
        TypeError: If the first element is neither a dict nor a list/tuple
            of dicts. Previously this case surfaced as an
            ``UnboundLocalError`` on the return statement.
    """
    # len() rather than truthiness so that a ``None`` argument still fails loudly.
    if len(MCTSInfo_list) == 0:
        return []

    def _load(node_info):
        return MCTSNode(device=self.device).load_info(node_info)

    first = MCTSInfo_list[0]
    if isinstance(first, dict):
        return [_load(node_info) for node_info in MCTSInfo_list]
    if isinstance(first, (list, tuple)) and (not first or isinstance(first[0], dict)):
        return [[_load(node_info) for node_info in single_label_list]
                for single_label_list in MCTSInfo_list]
    raise TypeError(
        "MCTSInfo_list must contain dicts or lists of dicts, "
        f"got element of type {type(first).__name__}"
    )
742
+
743
def write_from_MCTSNode_list(self, MCTSNode_list):
    """Convert ``MCTSNode`` objects into their serializable ``info`` values.

    Args:
        MCTSNode_list: Either a flat list of ``MCTSNode`` objects, or a list
            of such lists (one inner list per label).

    Returns:
        A structure with the same nesting as the input in which every node
        is replaced by its ``info`` attribute. An empty input yields an
        empty list.

    Raises:
        TypeError: If the first element is neither an ``MCTSNode`` nor a
            list/tuple of ``MCTSNode`` objects. Previously this case
            surfaced as an ``UnboundLocalError`` on the return statement.
    """
    # len() rather than truthiness so that a ``None`` argument still fails loudly.
    if len(MCTSNode_list) == 0:
        return []

    first = MCTSNode_list[0]
    if isinstance(first, MCTSNode):
        return [node.info for node in MCTSNode_list]
    if isinstance(first, (list, tuple)) and (not first or isinstance(first[0], MCTSNode)):
        return [[node.info for node in single_label_list]
                for single_label_list in MCTSNode_list]
    raise TypeError(
        "MCTSNode_list must contain MCTSNode objects or lists of them, "
        f"got element of type {type(first).__name__}"
    )
752
+
753
def explain(self, x: Tensor, edge_index: Tensor, label: int,
            max_nodes: int = 5,
            node_idx: Optional[int] = None,
            saved_MCTSInfo_list: Optional[List[List]] = None):
    """Explain the model's prediction of ``label`` and score the explanation.

    Search results are either rebuilt from ``saved_MCTSInfo_list`` (when it
    is truthy) or produced by a fresh MCTS run. The result selected by
    ``find_closest_node_result`` is then scored.

    Args:
        x: Node features passed to the model and to the search.
        edge_index: Graph connectivity passed to the model and to the search.
        label: Target class whose prediction is explained.
        max_nodes: Forwarded to ``find_closest_node_result``.
        node_idx: Target node; forwarded to ``get_mcts_class`` in node-level
            mode and used to index the model output for ``'origin'``.
        saved_MCTSInfo_list: Previously serialized search results; when
            truthy the MCTS run is skipped.

    Returns:
        A tuple ``(results, related_pred)``: the search results converted by
        ``write_from_MCTSNode_list``, and a dict with the keys ``'masked'``,
        ``'maskout'``, ``'origin'`` and ``'sparsity'``.
    """
    class_probs = self.model(x, edge_index).squeeze().softmax(dim=-1)
    have_saved = bool(saved_MCTSInfo_list)

    if self.explain_graph:
        if have_saved:
            results = self.read_from_MCTSInfo_list(saved_MCTSInfo_list)
        else:
            search_value_func = GnnNetsGC2valueFunc(self.model, target_class=label)
            reward = self.get_reward_func(search_value_func)
            self.mcts_state_map = self.get_mcts_class(x, edge_index, score_func=reward)
            results = self.mcts_state_map.mcts(verbose=self.verbose)

        # Value function used below to score the selected subgraph.
        value_func = GnnNetsGC2valueFunc(self.model, target_class=label)
    else:
        if have_saved:
            results = self.read_from_MCTSInfo_list(saved_MCTSInfo_list)

        # The search object is built even when results were loaded, because
        # the relabelled target node index is read from it.
        self.mcts_state_map = self.get_mcts_class(x, edge_index, node_idx=node_idx)
        self.new_node_idx = self.mcts_state_map.new_node_idx
        value_func = GnnNetsNC2valueFunc(self.model,
                                         node_idx=self.mcts_state_map.new_node_idx,
                                         target_class=label)

        if not have_saved:
            reward = self.get_reward_func(value_func,
                                          node_idx=self.mcts_state_map.new_node_idx)
            self.mcts_state_map.set_score_func(reward)
            results = self.mcts_state_map.mcts(verbose=self.verbose)

    chosen = find_closest_node_result(results, max_nodes=max_nodes)
    subgraph = chosen.data
    every_node = range(subgraph.x.shape[0])

    # Nodes of the explanation itself ...
    kept_nodes = [node for node in every_node if node in chosen.coalition]
    # ... and everything outside it; node-level mode additionally retains
    # the target node in this complement.
    dropped_nodes = [node for node in every_node if node not in chosen.coalition]
    if not self.explain_graph:
        dropped_nodes.append(self.new_node_idx)

    build_method = self.subgraph_building_method
    masked_score = gnn_score(kept_nodes,
                             subgraph,
                             value_func=value_func,
                             subgraph_building_method=build_method)
    maskout_score = gnn_score(dropped_nodes,
                              subgraph,
                              value_func=value_func,
                              subgraph_building_method=build_method)
    sparsity_score = sparsity(kept_nodes, subgraph,
                              subgraph_building_method=build_method)

    serialized = self.write_from_MCTSNode_list(results)
    return serialized, {'masked': masked_score,
                        'maskout': maskout_score,
                        'origin': class_probs[node_idx, label].item(),
                        'sparsity': sparsity_score}
823
+
824
def __call__(self, x: Tensor, edge_index: Tensor, **kwargs)\
        -> Tuple[None, List, List[Dict]]:
    r""" explain the GNN behavior for the graph using SubgraphX method
    Args:
        x (:obj:`torch.Tensor`): Node feature matrix with shape
          :obj:`[num_nodes, dim_node_feature]`
        edge_index (:obj:`torch.Tensor`): Graph connectivity in COO format
          with shape :obj:`[2, num_edges]`
        kwargs(:obj:`Dict`):
          The additional parameters
            - node_idx (:obj:`int`, :obj:`None`): The target node index when explain node classification task
            - max_nodes (:obj:`int`, :obj:`None`): The number of nodes in the final explanation results;
              when omitted or :obj:`None`, the default of :meth:`explain` is used

    When ``self.save`` is truthy, results are cached in
    ``<save_dir>/<filename>.pt``: an existing file is loaded and each
    label's entry is reused instead of re-running the search, and the
    results of this call are written back afterwards.

    :rtype: (:obj:`None`, List[torch.Tensor], List[Dict])
    """
    node_idx = kwargs.get('node_idx')
    max_nodes = kwargs.get('max_nodes')
    # Forward ``max_nodes`` only when it was supplied, so that omitting it
    # keeps ``explain``'s own default instead of overriding it with ``None``.
    explain_kwargs = {} if max_nodes is None else {'max_nodes': max_nodes}

    # one single-element tensor per class index
    ex_labels = tuple(torch.tensor([label]).to(self.device)
                      for label in range(self.num_classes))

    related_preds = []
    explanation_results = []
    saved_results = None
    cache_path = None
    if self.save:
        cache_path = os.path.join(self.save_dir, f"{self.filename}.pt")
        if os.path.isfile(cache_path):
            # NOTE(review): torch.load unpickles the file; only point
            # save_dir at trusted locations.
            saved_results = torch.load(cache_path)

    for label_idx, label in enumerate(ex_labels):
        # The cache holds one entry per label, in label order (see the
        # torch.save call below), so each label gets its own entry rather
        # than the whole multi-label list. A missing entry falls back to a
        # fresh search.
        if saved_results is not None and label_idx < len(saved_results):
            saved_for_label = saved_results[label_idx]
        else:
            saved_for_label = None

        results, related_pred = self.explain(x, edge_index,
                                             label=label,
                                             node_idx=node_idx,
                                             saved_MCTSInfo_list=saved_for_label,
                                             **explain_kwargs)
        related_preds.append(related_pred)
        explanation_results.append(results)

    if self.save:
        torch.save(explanation_results, cache_path)

    return None, explanation_results, related_preds