hjxdl 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/__init__.py +0 -0
- hdl/_version.py +16 -0
- hdl/args/__init__.py +0 -0
- hdl/args/loss_args.py +5 -0
- hdl/controllers/__init__.py +0 -0
- hdl/controllers/al/__init__.py +0 -0
- hdl/controllers/al/al.py +0 -0
- hdl/controllers/al/dispatcher.py +0 -0
- hdl/controllers/al/feedback.py +0 -0
- hdl/controllers/explain/__init__.py +0 -0
- hdl/controllers/explain/shapley.py +293 -0
- hdl/controllers/explain/subgraphx.py +865 -0
- hdl/controllers/train/__init__.py +0 -0
- hdl/controllers/train/rxn_train.py +219 -0
- hdl/controllers/train/train.py +50 -0
- hdl/controllers/train/train_ginet.py +316 -0
- hdl/controllers/train/trainer_base.py +155 -0
- hdl/controllers/train/trainer_iterative.py +389 -0
- hdl/data/__init__.py +0 -0
- hdl/data/dataset/__init__.py +0 -0
- hdl/data/dataset/base_dataset.py +98 -0
- hdl/data/dataset/fp/__init__.py +0 -0
- hdl/data/dataset/fp/fp_dataset.py +122 -0
- hdl/data/dataset/graph/__init__.py +0 -0
- hdl/data/dataset/graph/chiral.py +62 -0
- hdl/data/dataset/graph/gin.py +255 -0
- hdl/data/dataset/graph/molnet.py +362 -0
- hdl/data/dataset/loaders/__init__.py +0 -0
- hdl/data/dataset/loaders/chiral_graph.py +71 -0
- hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
- hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
- hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
- hdl/data/dataset/loaders/general.py +23 -0
- hdl/data/dataset/loaders/spliter.py +86 -0
- hdl/data/dataset/samplers/__init__.py +0 -0
- hdl/data/dataset/samplers/chiral.py +19 -0
- hdl/data/dataset/seq/__init__.py +0 -0
- hdl/data/dataset/seq/rxn_dataset.py +61 -0
- hdl/data/dataset/utils.py +31 -0
- hdl/data/to_mols.py +0 -0
- hdl/features/__init__.py +0 -0
- hdl/features/fp/__init__.py +0 -0
- hdl/features/fp/features_generators.py +235 -0
- hdl/features/graph/__init__.py +0 -0
- hdl/features/graph/featurization.py +297 -0
- hdl/features/utils/__init__.py +0 -0
- hdl/features/utils/utils.py +111 -0
- hdl/layers/__init__.py +0 -0
- hdl/layers/general/__init__.py +0 -0
- hdl/layers/general/gp.py +14 -0
- hdl/layers/general/linear.py +641 -0
- hdl/layers/graph/__init__.py +0 -0
- hdl/layers/graph/chiral_graph.py +230 -0
- hdl/layers/graph/gcn.py +16 -0
- hdl/layers/graph/gin.py +45 -0
- hdl/layers/graph/tetra.py +158 -0
- hdl/layers/graph/transformer.py +188 -0
- hdl/layers/sequential/__init__.py +0 -0
- hdl/metric_loss/__init__.py +0 -0
- hdl/metric_loss/loss.py +79 -0
- hdl/metric_loss/metric.py +178 -0
- hdl/metric_loss/multi_label.py +42 -0
- hdl/metric_loss/nt_xent.py +65 -0
- hdl/models/__init__.py +0 -0
- hdl/models/chiral_gnn.py +176 -0
- hdl/models/fast_transformer.py +234 -0
- hdl/models/ginet.py +189 -0
- hdl/models/linear.py +137 -0
- hdl/models/model_dict.py +18 -0
- hdl/models/norm_flows.py +33 -0
- hdl/models/optim_dict.py +16 -0
- hdl/models/rxn.py +63 -0
- hdl/models/utils.py +83 -0
- hdl/ops/__init__.py +0 -0
- hdl/ops/utils.py +42 -0
- hdl/optims/__init__.py +0 -0
- hdl/optims/nadam.py +86 -0
- hdl/utils/__init__.py +0 -0
- hdl/utils/chemical_tools/__init__.py +2 -0
- hdl/utils/chemical_tools/query_info.py +149 -0
- hdl/utils/chemical_tools/sdf.py +20 -0
- hdl/utils/database_tools/__init__.py +0 -0
- hdl/utils/database_tools/connect.py +28 -0
- hdl/utils/general/__init__.py +0 -0
- hdl/utils/general/glob.py +21 -0
- hdl/utils/schedulers/__init__.py +0 -0
- hdl/utils/schedulers/norm_lr.py +108 -0
- hjxdl-0.0.1.dist-info/METADATA +19 -0
- hjxdl-0.0.1.dist-info/RECORD +91 -0
- hjxdl-0.0.1.dist-info/WHEEL +5 -0
- hjxdl-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,865 @@
|
|
1
|
+
import os
|
2
|
+
import math
|
3
|
+
import copy
|
4
|
+
from typing import List, Tuple, Dict
|
5
|
+
from typing import Callable, Union, Optional
|
6
|
+
from collections import Counter
|
7
|
+
from functools import partial
|
8
|
+
from textwrap import wrap
|
9
|
+
|
10
|
+
import torch
|
11
|
+
import numpy as np
|
12
|
+
import networkx as nx
|
13
|
+
from rdkit import Chem
|
14
|
+
from torch import Tensor
|
15
|
+
from torch_geometric.data import Batch, Data
|
16
|
+
from torch_geometric.utils import to_networkx
|
17
|
+
import matplotlib.pyplot as plt
|
18
|
+
from torch_geometric.utils.num_nodes import maybe_num_nodes
|
19
|
+
from torch_geometric.nn.conv import MessagePassing
|
20
|
+
from torch_geometric.datasets import MoleculeNet
|
21
|
+
from torch_geometric.utils import remove_self_loops
|
22
|
+
|
23
|
+
from .shapley import GnnNetsGC2valueFunc, GnnNetsNC2valueFunc, \
|
24
|
+
gnn_score, mc_shapley, l_shapley, mc_l_shapley, NC_mc_l_shapley, sparsity
|
25
|
+
|
26
|
+
|
27
|
+
def find_closest_node_result(results, max_nodes):
    """Pick the highest-reward tree node whose coalition fits in ``max_nodes``.

    The candidates are ordered by coalition size and the smallest one is the
    initial choice. A candidate replaces the current choice only when its
    coalition has at most ``max_nodes`` nodes and its reward ``P`` is strictly
    higher. As a consequence, when no candidate satisfies the size limit the
    smallest coalition is returned even though it exceeds ``max_nodes``.

    Args:
        results: Non-empty iterable of objects exposing ``coalition`` (sized)
            and ``P`` (comparable reward).
        max_nodes: Upper bound on the coalition size.

    Returns:
        The selected element of ``results``.

    Raises:
        IndexError: If ``results`` is empty.
    """
    ordered = sorted(results, key=lambda node: len(node.coalition))

    best = ordered[0]
    for candidate in ordered:
        if len(candidate.coalition) <= max_nodes and candidate.P > best.P:
            best = candidate
    return best
|
37
|
+
|
38
|
+
|
39
|
+
def reward_func(reward_method, value_func, node_idx=None,
                local_radius=4, sample_num=100,
                subgraph_building_method='zero_filling'):
    """Build the reward callable selected by ``reward_method``.

    The returned object is a :func:`functools.partial` of one of the scoring
    functions imported from ``.shapley`` with the relevant keyword arguments
    already bound.

    Args:
        reward_method: Case-insensitive name; one of ``'gnn_score'``,
            ``'mc_shapley'``, ``'l_shapley'``, ``'mc_l_shapley'`` or
            ``'nc_mc_l_shapley'``.
        value_func: Bound as ``value_func`` in every variant.
        node_idx: Required (not ``None``) for ``'nc_mc_l_shapley'``; unused
            by the other variants.
        local_radius: Bound for the ``l_shapley`` based variants.
        sample_num: Bound for the Monte Carlo (``mc_*``) variants.
        subgraph_building_method: Bound as ``subgraph_building_method`` in
            every variant.

    Returns:
        A ``functools.partial`` wrapping the selected scoring function.

    Raises:
        AssertionError: If ``'nc_mc_l_shapley'`` is requested without
            ``node_idx``.
        NotImplementedError: If ``reward_method`` is not recognised.
    """
    # Normalise once instead of lower-casing in every branch.
    method = reward_method.lower()

    if method == 'gnn_score':
        return partial(gnn_score,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method)

    if method == 'mc_shapley':
        return partial(mc_shapley,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method,
                       sample_num=sample_num)

    if method == 'l_shapley':
        return partial(l_shapley,
                       local_radius=local_radius,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method)

    if method == 'mc_l_shapley':
        return partial(mc_l_shapley,
                       local_radius=local_radius,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method,
                       sample_num=sample_num)

    if method == 'nc_mc_l_shapley':
        # Raised explicitly (rather than via ``assert``) so that the check is
        # not stripped when Python runs with -O; the exception type is kept.
        if node_idx is None:
            raise AssertionError(" Wrong node idx input ")
        return partial(NC_mc_l_shapley,
                       node_idx=node_idx,
                       local_radius=local_radius,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method,
                       sample_num=sample_num)

    raise NotImplementedError(f"Unsupported reward method: {reward_method!r}")
|
77
|
+
|
78
|
+
|
79
|
+
def k_hop_subgraph_with_default_whole_graph(
        edge_index, node_idx=None, num_hops=3, relabel_nodes=False,
        num_nodes=None, flow='source_to_target'):
    r"""Computes the :math:`k`-hop subgraph of :obj:`edge_index` around node
    :attr:`node_idx`.
    It returns (1) the nodes involved in the subgraph, (2) the filtered
    :obj:`edge_index` connectivity, (3) the mapping from node indices in
    :obj:`node_idx` to their new location, and (4) the edge mask indicating
    which edges were preserved.

    When :attr:`node_idx` is :obj:`None`, :attr:`num_hops` is ignored: the
    node set is grown from node ``0`` along the edges until no new node is
    reached, and the returned mapping (3) is :obj:`None`.

    Args:
        edge_index (LongTensor): The edge indices.
        node_idx (int, list, tuple or :obj:`torch.Tensor`, optional): The
            central node(s). (default: :obj:`None`)
        num_hops: (int): The number of hops :math:`k`.
        relabel_nodes (bool, optional): If set to :obj:`True`, the resulting
            :obj:`edge_index` will be relabeled to hold consecutive indices
            starting from zero. (default: :obj:`False`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        flow (string, optional): The flow direction of :math:`k`-hop
            aggregation (:obj:`"source_to_target"` or
            :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)

    :rtype: (:class:`LongTensor`, :class:`LongTensor`, :class:`LongTensor`,
             :class:`BoolTensor`)
    """

    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    assert flow in ['source_to_target', 'target_to_source']
    if flow == 'target_to_source':
        row, col = edge_index
    else:
        # edge_index[0] -> edge_index[1]; col: source, row: target
        col, row = edge_index

    node_mask = row.new_empty(num_nodes, dtype=torch.bool)
    edge_mask = row.new_empty(row.size(0), dtype=torch.bool)

    inv = None

    if node_idx is None:
        # Seed on the same device as the edges; a CPU seed would make the
        # torch.cat below fail for edge indices living on another device.
        subsets = torch.tensor([0], device=row.device)
        cur_subsets = subsets
        while True:
            node_mask.fill_(False)
            node_mask[subsets] = True
            torch.index_select(node_mask, 0, row, out=edge_mask)
            subsets = torch.cat([subsets, col[edge_mask]]).unique()
            if cur_subsets.equal(subsets):
                # Fixed point reached: no new node was added in this sweep.
                subset = subsets
                break
            cur_subsets = subsets
    else:
        if isinstance(node_idx, (int, list, tuple)):
            node_idx = torch.tensor([node_idx], device=row.device, dtype=torch.int64).flatten()
        elif isinstance(node_idx, torch.Tensor) and len(node_idx.shape) == 0:
            # Promote a 0-dim tensor to shape (1,) on the edges' device.
            node_idx = node_idx.reshape(1).to(device=row.device, dtype=torch.int64)
        else:
            node_idx = node_idx.to(row.device)

        subsets = [node_idx]
        for _ in range(num_hops):
            node_mask.fill_(False)
            node_mask[subsets[-1]] = True
            torch.index_select(node_mask, 0, row, out=edge_mask)
            subsets.append(col[edge_mask])
        subset, inv = torch.cat(subsets).unique(return_inverse=True)
        # The first node_idx.numel() entries correspond to the seed nodes.
        inv = inv[:node_idx.numel()]

    # Keep only the edges with both endpoints inside the selected node set.
    node_mask.fill_(False)
    node_mask[subset] = True
    edge_mask = node_mask[row] & node_mask[col]

    edge_index = edge_index[:, edge_mask]

    if relabel_nodes:
        relabel = row.new_full((num_nodes,), -1)
        relabel[subset] = torch.arange(subset.size(0), device=row.device)
        edge_index = relabel[edge_index]

    # subset[i] is the original index of the node relabeled to i.
    return subset, edge_index, inv, edge_mask
|
160
|
+
|
161
|
+
|
162
|
+
def compute_scores(score_func, children):
    """Return one reward per child, evaluating only children not yet scored.

    A child whose ``P`` equals ``0`` is treated as unscored and is evaluated
    with ``score_func(child.coalition, child.data)``; any other child keeps
    its stored ``P``. The children themselves are not modified.

    Args:
        score_func: Callable taking ``(coalition, data)``.
        children: Iterable of objects exposing ``P``, ``coalition`` and ``data``.

    Returns:
        A list of scores in the same order as ``children``.
    """
    return [
        score_func(child.coalition, child.data) if child.P == 0 else child.P
        for child in children
    ]
|
171
|
+
|
172
|
+
|
173
|
+
class PlotUtils(object):
    """Draw a networkx graph with a chosen node subset highlighted.

    Args:
        dataset_name: Name used (case-insensitively) to choose the drawing
            routine in :meth:`plot` and the atom labelling in
            :meth:`plot_molecule`.
        is_show: When true, every plotting method calls ``plt.show()``.
    """

    def __init__(self, dataset_name, is_show=True):
        self.dataset_name = dataset_name
        self.is_show = is_show

    @staticmethod
    def _induced_edges(graph, nodelist):
        """Return the edges of ``graph`` whose two endpoints are both in ``nodelist``."""
        members = set(nodelist)
        return [(n_frm, n_to) for (n_frm, n_to) in graph.edges()
                if n_frm in members and n_to in members]

    def _save_show_close(self, figname):
        """Save to ``figname`` (if given), optionally show, then close a saved figure."""
        if figname is not None:
            plt.savefig(figname)
        if self.is_show:
            plt.show()
        if figname is not None:
            plt.close()

    def plot(self, graph, nodelist, figname, title_sentence=None, **kwargs):
        """Dispatch to the drawing routine matching ``self.dataset_name``.

        Extra inputs are read from ``kwargs``: ``x`` for molecule datasets,
        ``y`` and ``node_idx`` for the BA-shapes family, ``words`` for the
        sentence datasets.

        Raises:
            NotImplementedError: If the dataset name is not recognised.
        """
        name = self.dataset_name.lower()
        if name in ['ba_2motifs', 'ba_lrp']:
            self.plot_ba2motifs(graph, nodelist, title_sentence=title_sentence, figname=figname)
        elif name in ['mutag'] + list(MoleculeNet.names.keys()):
            x = kwargs.get('x')
            self.plot_molecule(graph, nodelist, x, title_sentence=title_sentence, figname=figname)
        elif name in ['ba_shapes', 'ba_community', 'tree_grid', 'tree_cycle']:
            y = kwargs.get('y')
            node_idx = kwargs.get('node_idx')
            self.plot_bashapes(graph, nodelist, y, node_idx, title_sentence=title_sentence, figname=figname)
        elif name in ['graph_sst2', 'graph_sst5', 'twitter']:
            words = kwargs.get('words')
            self.plot_sentence(graph, nodelist, words=words, title_sentence=title_sentence, figname=figname)
        else:
            raise NotImplementedError

    def plot_subgraph(self,
                      graph,
                      nodelist,
                      colors: Union[None, str, List[str]] = '#FFA500',
                      labels=None,
                      edge_color='gray',
                      edgelist=None,
                      subgraph_edge_color='black',
                      title_sentence=None,
                      figname=None):
        """Draw ``graph`` and overlay the edges between the nodes in ``nodelist``.

        When ``edgelist`` is ``None`` the highlighted edges are those of
        ``graph`` with both endpoints in ``nodelist``.
        """
        if edgelist is None:
            edgelist = self._induced_edges(graph, nodelist)
        pos = nx.kamada_kawai_layout(graph)
        members = set(nodelist)
        pos_nodelist = {k: v for k, v in pos.items() if k in members}

        nx.draw_networkx_nodes(graph, pos,
                               nodelist=list(graph.nodes()),
                               node_color=colors,
                               node_size=300)

        nx.draw_networkx_edges(graph, pos, width=3, edge_color=edge_color, arrows=False)

        nx.draw_networkx_edges(graph, pos=pos_nodelist,
                               edgelist=edgelist, width=6,
                               edge_color=subgraph_edge_color,
                               arrows=False)

        if labels is not None:
            nx.draw_networkx_labels(graph, pos, labels)

        plt.axis('off')
        # The title must be set before saving, otherwise the saved file lacks it.
        if title_sentence is not None:
            plt.title('\n'.join(wrap(title_sentence, width=60)))
        self._save_show_close(figname)

    def plot_subgraph_with_nodes(self,
                                 graph,
                                 nodelist,
                                 node_idx,
                                 colors='#FFA500',
                                 labels=None,
                                 edge_color='gray',
                                 edgelist=None,
                                 subgraph_edge_color='black',
                                 title_sentence=None,
                                 figname=None):
        """Like :meth:`plot_subgraph`, additionally drawing ``node_idx`` enlarged."""
        node_idx = int(node_idx)
        if edgelist is None:
            edgelist = self._induced_edges(graph, nodelist)

        pos = nx.kamada_kawai_layout(graph)  # calculate according to graph.nodes()
        members = set(nodelist)
        pos_nodelist = {k: v for k, v in pos.items() if k in members}

        nx.draw_networkx_nodes(graph, pos,
                               nodelist=list(graph.nodes()),
                               node_color=colors,
                               node_size=300)
        if isinstance(colors, list):
            # Colors are ordered like graph.nodes(); find the target's position.
            node_idx_color = colors[list(graph.nodes()).index(node_idx)]
        else:
            node_idx_color = colors

        nx.draw_networkx_nodes(graph, pos=pos,
                               nodelist=[node_idx],
                               node_color=node_idx_color,
                               node_size=600)

        nx.draw_networkx_edges(graph, pos, width=3, edge_color=edge_color, arrows=False)

        nx.draw_networkx_edges(graph, pos=pos_nodelist,
                               edgelist=edgelist, width=3,
                               edge_color=subgraph_edge_color,
                               arrows=False)

        if labels is not None:
            nx.draw_networkx_labels(graph, pos, labels)

        plt.axis('off')
        if title_sentence is not None:
            plt.title('\n'.join(wrap(title_sentence, width=60)))
        self._save_show_close(figname)

    def plot_sentence(self, graph, nodelist, words, edgelist=None, title_sentence=None, figname=None):
        """Draw a word graph, labelling node ``i`` with ``words[i]``.

        When ``nodelist`` is not ``None`` its nodes and the edges between them
        are drawn in yellow underneath the regular graph drawing. The title is
        the wrapped sentence, followed by ``title_sentence`` when given.
        """
        pos = nx.kamada_kawai_layout(graph)
        words_dict = {i: words[i] for i in graph.nodes}
        if nodelist is not None:
            members = set(nodelist)
            pos_coalition = {k: v for k, v in pos.items() if k in members}
            nx.draw_networkx_nodes(graph, pos_coalition,
                                   nodelist=nodelist,
                                   node_color='yellow',
                                   node_shape='o',
                                   node_size=500)
            if edgelist is None:
                edgelist = self._induced_edges(graph, nodelist)
            nx.draw_networkx_edges(graph, pos=pos_coalition, edgelist=edgelist, width=5, edge_color='yellow')

        nx.draw_networkx_nodes(graph, pos, nodelist=list(graph.nodes()), node_size=300)

        nx.draw_networkx_edges(graph, pos, width=2, edge_color='grey')
        nx.draw_networkx_labels(graph, pos, words_dict)

        plt.axis('off')
        title = '\n'.join(wrap(' '.join(words), width=50))
        if title_sentence is not None:
            # Separate the sentence from the extra title line; without the
            # newline the two texts run into each other.
            title += '\n' + '\n'.join(wrap(title_sentence, width=60))
        plt.title(title)
        self._save_show_close(figname)

    def plot_ba2motifs(self,
                       graph,
                       nodelist,
                       edgelist=None,
                       title_sentence=None,
                       figname=None):
        """Draw with the default colours of :meth:`plot_subgraph`."""
        return self.plot_subgraph(graph, nodelist,
                                  edgelist=edgelist,
                                  title_sentence=title_sentence,
                                  figname=figname)

    def plot_molecule(self,
                      graph,
                      nodelist,
                      x,
                      edgelist=None,
                      title_sentence=None,
                      figname=None):
        """Draw a molecule graph with per-atom labels and colours derived from ``x``.

        For ``mutag`` the atom type is the column of ``x`` holding a ``1``
        (assumes one-hot rows -- TODO confirm against the dataset). For
        MoleculeNet datasets ``x[:, 0]`` is passed to RDKit's periodic table
        as the atomic number.

        Raises:
            NotImplementedError: If the dataset is neither ``mutag`` nor a
                MoleculeNet dataset name.
        """
        # Compare case-insensitively, consistent with the dispatch in plot().
        name = self.dataset_name.lower()
        if name == 'mutag':
            node_dict = {0: 'C', 1: 'N', 2: 'O', 3: 'F', 4: 'I', 5: 'Cl', 6: 'Br'}
            node_idxs = {k: int(v) for k, v in enumerate(np.where(x.cpu().numpy() == 1)[1])}
            node_labels = {k: node_dict[v] for k, v in node_idxs.items()}
            node_color = ['#E49D1C', '#4970C6', '#FF5357', '#29A329', 'brown', 'darkslategray', '#F0EA00']
            colors = [node_color[v % len(node_color)] for v in node_idxs.values()]

        elif self.dataset_name in MoleculeNet.names or name in MoleculeNet.names:
            node_idxs = {k: int(v) for k, v in enumerate(x[:, 0])}
            periodic_table = Chem.GetPeriodicTable()
            node_labels = {k: periodic_table.GetElementSymbol(v)
                           for k, v in node_idxs.items()}
            node_color = ['#29A329', 'lime', '#F0EA00', 'maroon', 'brown', '#E49D1C', '#4970C6', '#FF5357']
            colors = [node_color[(v - 1) % len(node_color)] for v in node_idxs.values()]
        else:
            raise NotImplementedError

        self.plot_subgraph(graph, nodelist,
                           colors=colors,
                           labels=node_labels,
                           edgelist=edgelist,
                           edge_color='gray',
                           subgraph_edge_color='black',
                           title_sentence=title_sentence,
                           figname=figname)

    def plot_bashapes(self,
                      graph,
                      nodelist,
                      y,
                      node_idx,
                      edgelist=None,
                      title_sentence=None,
                      figname=None):
        """Draw with one of four colours per node, chosen from the flattened ``y``."""
        node_idxs = {k: int(v) for k, v in enumerate(y.reshape(-1).tolist())}
        node_color = ['#FFA500', '#4970C6', '#FE0000', 'green']
        colors = [node_color[v % len(node_color)] for v in node_idxs.values()]
        self.plot_subgraph_with_nodes(graph,
                                      nodelist,
                                      node_idx,
                                      colors,
                                      edgelist=edgelist,
                                      title_sentence=title_sentence,
                                      figname=figname,
                                      subgraph_edge_color='black')
|
389
|
+
|
390
|
+
|
391
|
+
class MCTSNode(object):
    """One state of the search tree: a node coalition plus its search statistics.

    Attributes are plain fields: ``coalition``, ``data``, ``ori_graph``,
    ``device``, ``c_puct``, ``children`` and the statistics ``W`` (sum of
    node values), ``N`` (visit count) and ``P`` (reward). When ``load_dict``
    is given, :meth:`load_info` overwrites the fields it stores.
    """

    def __init__(self, coalition: list = None, data: Data = None, ori_graph: nx.Graph = None,
                 c_puct: float = 10.0, W: float = 0, N: int = 0, P: float = 0,
                 load_dict: Optional[Dict] = None, device='cpu'):
        # graph-related state
        self.data = data
        self.coalition = coalition
        self.ori_graph = ori_graph
        self.device = device
        # search-related state
        self.c_puct = c_puct
        self.children = []
        self.W = W  # sum of node value
        self.N = N  # times of arrival
        self.P = P  # property score (reward)
        if load_dict is not None:
            self.load_info(load_dict)

    def Q(self):
        """Mean value ``W / N``, or ``0`` while the node is unvisited."""
        if self.N > 0:
            return self.W / self.N
        return 0

    def U(self, n):
        """Exploration term ``c_puct * P * sqrt(n) / (1 + N)``."""
        numerator = self.c_puct * self.P * math.sqrt(n)
        return numerator / (1 + self.N)

    @property
    def info(self):
        """Dictionary snapshot of the node (without children).

        ``data`` is obtained through ``self.data.to('cpu')``.
        NOTE(review): whether that call also moves ``self.data`` itself
        depends on the ``Data.to`` implementation -- confirm.
        """
        return {
            'data': self.data.to('cpu'),
            'coalition': self.coalition,
            'ori_graph': self.ori_graph,
            'W': self.W,
            'N': self.N,
            'P': self.P,
        }

    def load_info(self, info_dict):
        """Restore the fields produced by :attr:`info` and reset ``children``.

        ``data`` is moved to ``self.device``. Returns ``self``.
        """
        self.coalition = info_dict['coalition']
        self.ori_graph = info_dict['ori_graph']
        self.data = info_dict['data'].to(self.device)
        self.W = info_dict['W']
        self.N = info_dict['N']
        self.P = info_dict['P']
        self.children = []
        return self
|
434
|
+
|
435
|
+
|
436
|
+
class MCTS(object):
    r"""
    Monte Carlo Tree Search Method
    Args:
        X (:obj:`torch.Tensor`): Input node features
        edge_index (:obj:`torch.Tensor`): The edge indices.
        num_hops (:obj:`int`): The number of hops :math:`k`.
        n_rollout (:obj:`int`): The number of sequence to build the monte carlo tree.
        min_atoms (:obj:`int`): The number of atoms for the subgraph in the monte carlo tree leaf node.
        c_puct (:obj:`float`): The hyper-parameter to encourage exploration while searching.
        expand_atoms (:obj:`int`): The number of children to expand.
        high2low (:obj:`bool`): Whether to expand children tree node from high degree nodes to low degree nodes.
        node_idx (:obj:`int`): The target node index to extract the neighborhood.
        score_func (:obj:`Callable`): The reward function for tree node, such as mc_shapley and mc_l_shapley.
        device: Device passed on to every created :class:`MCTSNode`.
    """

    def __init__(self, X: torch.Tensor, edge_index: torch.Tensor, num_hops: int,
                 n_rollout: int = 10, min_atoms: int = 3, c_puct: float = 10.0,
                 expand_atoms: int = 14, high2low: bool = False,
                 node_idx: Optional[int] = None, score_func: Optional[Callable] = None, device='cpu'):

        self.X = X
        self.edge_index = edge_index
        self.device = device
        self.num_hops = num_hops
        self.data = Data(x=self.X, edge_index=self.edge_index)
        # The networkx view is built without self loops.
        graph_data = Data(x=self.X, edge_index=remove_self_loops(self.edge_index)[0])
        self.graph = to_networkx(graph_data, to_undirected=True)
        self.data = Batch.from_data_list([self.data])
        self.num_nodes = self.graph.number_of_nodes()
        self.score_func = score_func
        self.n_rollout = n_rollout
        self.min_atoms = min_atoms
        self.c_puct = c_puct
        self.expand_atoms = expand_atoms
        self.high2low = high2low
        # Index of the target node after relabeling; None for whole-graph search.
        self.new_node_idx = None

        # extract the sub-graph and change the node indices.
        if node_idx is not None:
            self.ori_node_idx = node_idx
            self.ori_graph = copy.copy(self.graph)
            x, edge_index, subset, _, _ = \
                self.__subgraph__(node_idx, self.X, self.edge_index, self.num_hops)
            self.data = Batch.from_data_list([Data(x=x, edge_index=edge_index)])
            self.graph = self.ori_graph.subgraph(subset.tolist())
            # original node id -> position inside ``subset``
            mapping = {int(v): k for k, v in enumerate(subset)}
            self.graph = nx.relabel_nodes(self.graph, mapping)
            self.new_node_idx = torch.where(subset == self.ori_node_idx)[0].item()
            self.num_nodes = self.graph.number_of_nodes()
            self.subset = subset

        self.root_coalition = list(range(self.num_nodes))
        self.MCTSNodeClass = partial(MCTSNode, data=self.data, ori_graph=self.graph,
                                     c_puct=self.c_puct, device=self.device)
        self.root = self.MCTSNodeClass(self.root_coalition)
        # Explored states, keyed by the string form of the sorted coalition.
        self.state_map = {str(self.root.coalition): self.root}

    def set_score_func(self, score_func):
        """Replace the reward function used to score newly expanded nodes."""
        self.score_func = score_func

    @staticmethod
    def __subgraph__(node_idx, x, edge_index, num_hops, **kwargs):
        """Extract the relabeled ``num_hops`` neighbourhood of ``node_idx``.

        Tensors in ``kwargs`` whose first dimension equals the number of
        nodes (or edges) are filtered with the node subset (or edge mask).

        Returns:
            ``(x, edge_index, subset, edge_mask, kwargs)``.
        """
        num_nodes, num_edges = x.size(0), edge_index.size(1)
        subset, edge_index, _, edge_mask = k_hop_subgraph_with_default_whole_graph(
            edge_index, node_idx, num_hops, relabel_nodes=True, num_nodes=num_nodes)

        x = x[subset]
        for key, item in kwargs.items():
            if torch.is_tensor(item) and item.size(0) == num_nodes:
                item = item[subset]
            elif torch.is_tensor(item) and item.size(0) == num_edges:
                item = item[edge_mask]
            kwargs[key] = item

        return x, edge_index, subset, edge_mask, kwargs

    def _expand(self, tree_node):
        """Create the children of ``tree_node`` by pruning one node at a time."""
        # ``is not None`` matters: a target node relabeled to index 0 is still
        # a target node and must never be pruned.
        has_target = self.new_node_idx is not None

        node_degree_list = sorted(self.graph.subgraph(tree_node.coalition).degree,
                                  key=lambda x: x[1], reverse=self.high2low)
        all_nodes = [x[0] for x in node_degree_list]

        if has_target:
            expand_nodes = [node for node in all_nodes if node != self.new_node_idx]
        else:
            expand_nodes = all_nodes

        if len(all_nodes) > self.expand_atoms:
            expand_nodes = expand_nodes[:self.expand_atoms]

        for each_node in expand_nodes:
            # for each node, pruning it and get the remaining sub-graph
            subgraph_coalition = [node for node in all_nodes if node != each_node]

            subgraphs = [self.graph.subgraph(c)
                         for c in nx.connected_components(self.graph.subgraph(subgraph_coalition))]

            if has_target:
                # keep the component that still contains the target node
                for sub in subgraphs:
                    if self.new_node_idx in sub.nodes():
                        main_sub = sub
            else:
                # keep the (first) largest component
                main_sub = max(subgraphs, key=lambda sub: sub.number_of_nodes())

            new_graph_coalition = sorted(main_sub.nodes())

            # check the state map and merge the same sub-graph; every key is
            # the string form of a sorted coalition, so a direct lookup finds
            # exactly the node an equality scan over all states would find
            state_key = str(new_graph_coalition)
            new_node = self.state_map.get(state_key)
            if new_node is None:
                new_node = self.MCTSNodeClass(new_graph_coalition)
                self.state_map[state_key] = new_node

            target = Counter(new_graph_coalition)
            if not any(Counter(child.coalition) == target for child in tree_node.children):
                tree_node.children.append(new_node)

    def mcts_rollout(self, tree_node):
        """Run one rollout from ``tree_node`` and return the reward reached.

        The recursion stops at a node whose coalition has at most
        ``min_atoms`` nodes and returns that node's ``P``. On the way back
        the value is added to ``W`` and ``N`` is incremented for every
        selected node.
        """
        if len(tree_node.coalition) <= self.min_atoms:
            return tree_node.P

        # Expand if this node has never been visited
        if len(tree_node.children) == 0:
            self._expand(tree_node)

            scores = compute_scores(self.score_func, tree_node.children)
            for child, score in zip(tree_node.children, scores):
                child.P = score

        sum_count = sum(c.N for c in tree_node.children)
        selected_node = max(tree_node.children, key=lambda x: x.Q() + x.U(sum_count))
        v = self.mcts_rollout(selected_node)
        selected_node.W += v
        selected_node.N += 1
        return v

    def mcts(self, verbose=True):
        """Run ``n_rollout`` rollouts from the root.

        Returns:
            All explored tree nodes, sorted by reward ``P`` in descending order.
        """
        if verbose:
            print(f"The nodes in graph is {self.graph.number_of_nodes()}")
        for rollout_idx in range(self.n_rollout):
            self.mcts_rollout(self.root)
            if verbose:
                print(f"At the {rollout_idx} rollout, {len(self.state_map)} states that have been explored.")

        return sorted(self.state_map.values(), key=lambda x: x.P, reverse=True)
|
593
|
+
|
594
|
+
|
595
|
+
class SubgraphX(object):
|
596
|
+
r"""
|
597
|
+
The implementation of paper
|
598
|
+
`On Explainability of Graph Neural Networks via Subgraph Explorations <https://arxiv.org/abs/2102.05152>`_.
|
599
|
+
Args:
|
600
|
+
model (:obj:`torch.nn.Module`): The target model prepared to explain
|
601
|
+
num_classes(:obj:`int`): Number of classes for the datasets
|
602
|
+
num_hops(:obj:`int`, :obj:`None`): The number of hops to extract neighborhood of target node
|
603
|
+
(default: :obj:`None`)
|
604
|
+
explain_graph(:obj:`bool`): Whether to explain graph classification model (default: :obj:`True`)
|
605
|
+
rollout(:obj:`int`): Number of iteration to get the prediction
|
606
|
+
min_atoms(:obj:`int`): Number of atoms of the leaf node in search tree
|
607
|
+
c_puct(:obj:`float`): The hyperparameter which encourages the exploration
|
608
|
+
expand_atoms(:obj:`int`): The number of atoms to expand
|
609
|
+
when extend the child nodes in the search tree
|
610
|
+
high2low(:obj:`bool`): Whether to expand children nodes from high degree to low degree when
|
611
|
+
extend the child nodes in the search tree (default: :obj:`False`)
|
612
|
+
local_radius(:obj:`int`): Number of local radius to calculate :obj:`l_shapley`, :obj:`mc_l_shapley`
|
613
|
+
sample_num(:obj:`int`): Sampling time of monte carlo sampling approximation for
|
614
|
+
:obj:`mc_shapley`, :obj:`mc_l_shapley` (default: :obj:`mc_l_shapley`)
|
615
|
+
reward_method(:obj:`str`): The command string to select the
|
616
|
+
subgraph_building_method(:obj:`str`): The command string for different subgraph building method,
|
617
|
+
such as :obj:`zero_filling`, :obj:`split` (default: :obj:`zero_filling`)
|
618
|
+
save_dir(:obj:`str`, :obj:`None`): Root directory to save the explanation results (default: :obj:`None`)
|
619
|
+
filename(:obj:`str`): The filename of results
|
620
|
+
vis(:obj:`bool`): Whether to show the visualization (default: :obj:`True`)
|
621
|
+
Example:
|
622
|
+
>>> # For graph classification task
|
623
|
+
>>> subgraphx = SubgraphX(model=model, num_classes=2)
|
624
|
+
>>> _, explanation_results, related_preds = subgraphx(x, edge_index)
|
625
|
+
"""
|
626
|
+
def __init__(self, model, num_classes: int, device, num_hops: Optional[int] = None, verbose: bool = False,
             explain_graph: bool = True, rollout: int = 20, min_atoms: int = 5, c_puct: float = 10.0,
             expand_atoms=14, high2low=False, local_radius=4, sample_num=100, reward_method='mc_l_shapley',
             subgraph_building_method='zero_filling', save_dir: Optional[str] = None,
             filename: str = 'example', vis: bool = True):
    """Store the configuration; put ``model`` in eval mode and move it to ``device``.

    ``num_hops`` is resolved through ``update_num_hops`` and ``save`` is
    true exactly when ``save_dir`` is given.
    """
    # target model: evaluation mode, placed on the requested device
    self.model = model
    self.model.eval()
    self.device = device
    self.model.to(self.device)

    # task description
    self.num_classes = num_classes
    self.num_hops = self.update_num_hops(num_hops)
    self.explain_graph = explain_graph
    self.verbose = verbose

    # search (MCTS) settings
    self.rollout, self.min_atoms, self.c_puct = rollout, min_atoms, c_puct
    self.expand_atoms, self.high2low = expand_atoms, high2low

    # reward settings
    self.local_radius, self.sample_num = local_radius, sample_num
    self.reward_method = reward_method
    self.subgraph_building_method = subgraph_building_method

    # output settings
    self.vis = vis
    self.save_dir = save_dir
    self.filename = filename
    self.save = self.save_dir is not None
|
659
|
+
|
660
|
+
def update_num_hops(self, num_hops):
    """Resolve the number of hops to use.

    Args:
        num_hops: Explicit value, or ``None`` to derive it from the model.

    Returns:
        ``num_hops`` unchanged when it is not ``None``; otherwise the number
        of ``MessagePassing`` instances found in ``self.model.modules()``.
    """
    if num_hops is None:
        return sum(1 for layer in self.model.modules()
                   if isinstance(layer, MessagePassing))
    return num_hops
|
669
|
+
|
670
|
+
def get_reward_func(self, value_func, node_idx=None):
    """Build the reward function from the stored reward settings.

    Args:
        value_func: Value function forwarded to ``reward_func``.
        node_idx: Target node; required (asserted) for node-level
            explanation and ignored for graph-level explanation.

    Returns:
        Whatever ``reward_func`` returns for the configured settings.
    """
    if not self.explain_graph:
        assert node_idx is not None
    else:
        # Graph-level explanations have no target node.
        node_idx = None

    settings = {
        'reward_method': self.reward_method,
        'value_func': value_func,
        'node_idx': node_idx,
        'local_radius': self.local_radius,
        'sample_num': self.sample_num,
        'subgraph_building_method': self.subgraph_building_method,
    }
    return reward_func(**settings)
|
681
|
+
|
682
|
+
def get_mcts_class(self, x, edge_index, node_idx: int = None, score_func: Callable = None):
    """Create an ``MCTS`` instance configured with the stored search settings.

    Args:
        x: Node features, forwarded to ``MCTS``.
        edge_index: Graph connectivity, forwarded to ``MCTS``.
        node_idx: Target node; required (asserted) for node-level
            explanation and ignored for graph-level explanation.
        score_func: Scoring callable forwarded to ``MCTS``.

    Returns:
        The constructed ``MCTS`` object.
    """
    if not self.explain_graph:
        assert node_idx is not None
    else:
        # Graph-level explanations have no target node.
        node_idx = None

    search_settings = {
        'node_idx': node_idx,
        'device': self.device,
        'score_func': score_func,
        'num_hops': self.num_hops,
        'n_rollout': self.rollout,
        'min_atoms': self.min_atoms,
        'c_puct': self.c_puct,
        'expand_atoms': self.expand_atoms,
        'high2low': self.high2low,
    }
    return MCTS(x, edge_index, **search_settings)
|
697
|
+
|
698
|
+
def visualization(self, results: list,
                  max_nodes: int, plot_utils: PlotUtils, words: Optional[list] = None,
                  y: Optional[Tensor] = None, title_sentence: Optional[str] = None,
                  vis_name: Optional[str] = None):
    """Plot the explanation picked from ``results``.

    The node to draw is chosen by ``find_closest_node_result(results,
    max_nodes=max_nodes)``; its ``ori_graph`` and ``coalition`` are passed
    to ``plot_utils.plot``.

    Args:
        results: Search results to choose the explanation from.
        max_nodes: Forwarded to ``find_closest_node_result``.
        plot_utils: Object whose ``plot`` method does the drawing.
        words: Graph-level only; when given it is passed as ``words=``,
            otherwise the node features are passed as ``x=``.
        y: Node-level only; indexed with ``self.mcts_state_map.subset``,
            so it must not be ``None`` in that mode.
        title_sentence: Forwarded to ``plot`` unchanged.
        vis_name: Figure name. Forced to ``None`` when saving is disabled;
            defaults to ``"<filename>.png"`` when saving is enabled.

    Note:
        Node-level mode reads ``self.mcts_state_map``, which is assigned in
        :meth:`explain`.
    """
    if not self.save:
        vis_name = None
    elif vis_name is None:
        vis_name = f"{self.filename}.png"

    chosen = find_closest_node_result(results, max_nodes=max_nodes)

    if not self.explain_graph:
        # Node-level: gather the label of every node in the plotted subgraph.
        subset = self.mcts_state_map.subset
        labels_on_cpu = y[subset].to('cpu')
        node_labels = torch.tensor(
            [labels_on_cpu[node].item() for node in chosen.ori_graph.nodes()])
        plot_utils.plot(chosen.ori_graph,
                        chosen.coalition,
                        node_idx=self.mcts_state_map.new_node_idx,
                        title_sentence=title_sentence,
                        y=node_labels,
                        figname=vis_name)
        return

    if words is None:
        plot_utils.plot(chosen.ori_graph,
                        chosen.coalition,
                        x=chosen.data.x,
                        title_sentence=title_sentence,
                        figname=vis_name)
    else:
        plot_utils.plot(chosen.ori_graph,
                        chosen.coalition,
                        words=words,
                        title_sentence=title_sentence,
                        figname=vis_name)
|
732
|
+
|
733
|
+
def read_from_MCTSInfo_list(self, MCTSInfo_list):
    """Rebuild search results from their saved ``info`` dictionaries.

    Args:
        MCTSInfo_list: Either a flat list of info dicts, or a list of such
            lists (one inner list per label). Must be non-empty.

    Returns:
        A list with the same nesting as the input in which every dict has
        been replaced by ``MCTSNode(device=self.device).load_info(info)``.
    """
    def rebuild(info):
        return MCTSNode(device=self.device).load_info(info)

    head = MCTSInfo_list[0]
    if isinstance(head, dict):
        ret_list = [rebuild(info) for info in MCTSInfo_list]
    elif isinstance(head[0], dict):
        ret_list = [[rebuild(info) for info in per_label]
                    for per_label in MCTSInfo_list]
    return ret_list
|
742
|
+
|
743
|
+
def write_from_MCTSNode_list(self, MCTSNode_list):
    """Convert search results into their serializable ``info`` form.

    Args:
        MCTSNode_list: Either a flat list of ``MCTSNode`` objects, or a list
            of such lists (one inner list per label). Must be non-empty.

    Returns:
        A list with the same nesting as the input in which every node has
        been replaced by its ``info`` attribute.
    """
    head = MCTSNode_list[0]
    if isinstance(head, MCTSNode):
        ret_list = [tree_node.info for tree_node in MCTSNode_list]
    elif isinstance(head[0], MCTSNode):
        ret_list = [[tree_node.info for tree_node in per_label]
                    for per_label in MCTSNode_list]
    return ret_list
|
752
|
+
|
753
|
+
def explain(self, x: Tensor, edge_index: Tensor, label: int,
            max_nodes: int = 5,
            node_idx: Optional[int] = None,
            saved_MCTSInfo_list: Optional[List[List]] = None):
    """Explain the model prediction for a single target class.

    Search results are either restored from ``saved_MCTSInfo_list`` or
    produced by running MCTS. The result selected by
    ``find_closest_node_result(results, max_nodes=max_nodes)`` is then
    scored with ``gnn_score`` twice (coalition kept / coalition removed)
    and with ``sparsity``.

    Args:
        x: Node feature matrix passed to the model.
        edge_index: Graph connectivity passed to the model.
        label: Target class. ``__call__`` passes a one-element tensor here.
        max_nodes: Forwarded to ``find_closest_node_result``.
        node_idx: Target node for node-level explanation. Ignored for
            graph-level explanation.
        saved_MCTSInfo_list: Previously saved results for this label; when
            truthy the search itself is skipped.

    Returns:
        ``(results, related_pred)`` where ``results`` is the serializable
        form produced by :meth:`write_from_MCTSNode_list` and
        ``related_pred`` is a dict with keys ``'masked'``, ``'maskout'``,
        ``'origin'`` and ``'sparsity'``.

    Side effects:
        Sets ``self.mcts_state_map`` whenever an MCTS object is built, and
        ``self.new_node_idx`` in node-level mode.
    """
    # Only a scalar is read from ``probs`` below, so no autograd graph is needed.
    with torch.no_grad():
        probs = self.model(x, edge_index).squeeze().softmax(dim=-1)

    if self.explain_graph:
        # Graph-level explanations have no target node; get_reward_func and
        # get_mcts_class ignore node_idx in this mode, and so must the
        # ``probs`` lookup below.
        node_idx = None
        value_func = GnnNetsGC2valueFunc(self.model, target_class=label)

        if saved_MCTSInfo_list:
            results = self.read_from_MCTSInfo_list(saved_MCTSInfo_list)
        else:
            payoff_func = self.get_reward_func(value_func)
            self.mcts_state_map = self.get_mcts_class(x, edge_index, score_func=payoff_func)
            results = self.mcts_state_map.mcts(verbose=self.verbose)
    else:
        if saved_MCTSInfo_list:
            results = self.read_from_MCTSInfo_list(saved_MCTSInfo_list)

        # The MCTS object is built even when results are restored, because
        # it extracts the subgraph and relabels the nodes; new_node_idx is
        # needed for the value function and the mask-out list.
        self.mcts_state_map = self.get_mcts_class(x, edge_index, node_idx=node_idx)
        self.new_node_idx = self.mcts_state_map.new_node_idx
        value_func = GnnNetsNC2valueFunc(self.model,
                                         node_idx=self.mcts_state_map.new_node_idx,
                                         target_class=label)

        if not saved_MCTSInfo_list:
            payoff_func = self.get_reward_func(value_func,
                                               node_idx=self.mcts_state_map.new_node_idx)
            self.mcts_state_map.set_score_func(payoff_func)
            results = self.mcts_state_map.mcts(verbose=self.verbose)

    tree_node_x = find_closest_node_result(results, max_nodes=max_nodes)
    num_nodes = tree_node_x.data.x.shape[0]

    # Keep the important structure.
    masked_node_list = [node for node in range(num_nodes)
                        if node in tree_node_x.coalition]

    # Remove the important structure; for node classification the target
    # node is kept even when the important structure is removed.
    maskout_node_list = [node for node in range(num_nodes)
                         if node not in tree_node_x.coalition]
    if not self.explain_graph:
        maskout_node_list += [self.new_node_idx]

    masked_score = gnn_score(masked_node_list,
                             tree_node_x.data,
                             value_func=value_func,
                             subgraph_building_method=self.subgraph_building_method)

    maskout_score = gnn_score(maskout_node_list,
                              tree_node_x.data,
                              value_func=value_func,
                              subgraph_building_method=self.subgraph_building_method)

    sparsity_score = sparsity(masked_node_list, tree_node_x.data,
                              subgraph_building_method=self.subgraph_building_method)

    results = self.write_from_MCTSNode_list(results)
    related_pred = {'masked': masked_score,
                    'maskout': maskout_score,
                    'origin': probs[node_idx, label].item(),
                    'sparsity': sparsity_score}

    return results, related_pred
|
823
|
+
|
824
|
+
def __call__(self, x: Tensor, edge_index: Tensor, **kwargs)\
|
825
|
+
-> Tuple[None, List, List[Dict]]:
|
826
|
+
r""" explain the GNN behavior for the graph using SubgraphX method
|
827
|
+
Args:
|
828
|
+
x (:obj:`torch.Tensor`): Node feature matrix with shape
|
829
|
+
:obj:`[num_nodes, dim_node_feature]`
|
830
|
+
edge_index (:obj:`torch.Tensor`): Graph connectivity in COO format
|
831
|
+
with shape :obj:`[2, num_edges]`
|
832
|
+
kwargs(:obj:`Dict`):
|
833
|
+
The additional parameters
|
834
|
+
- node_idx (:obj:`int`, :obj:`None`): The target node index when explain node classification task
|
835
|
+
- max_nodes (:obj:`int`, :obj:`None`): The number of nodes in the final explanation results
|
836
|
+
:rtype: (:obj:`None`, List[torch.Tensor], List[Dict])
|
837
|
+
"""
|
838
|
+
node_idx = kwargs.get('node_idx')
|
839
|
+
max_nodes = kwargs.get('max_nodes') # default max subgraph size
|
840
|
+
|
841
|
+
# collect all the class index
|
842
|
+
labels = tuple(label for label in range(self.num_classes))
|
843
|
+
ex_labels = tuple(torch.tensor([label]).to(self.device) for label in labels)
|
844
|
+
|
845
|
+
related_preds = []
|
846
|
+
explanation_results = []
|
847
|
+
saved_results = None
|
848
|
+
if self.save:
|
849
|
+
if os.path.isfile(os.path.join(self.save_dir, f"{self.filename}.pt")):
|
850
|
+
saved_results = torch.load(os.path.join(self.save_dir, f"{self.filename}.pt"))
|
851
|
+
|
852
|
+
for label_idx, label in enumerate(ex_labels):
|
853
|
+
results, related_pred = self.explain(x, edge_index,
|
854
|
+
label=label,
|
855
|
+
max_nodes=max_nodes,
|
856
|
+
node_idx=node_idx,
|
857
|
+
saved_MCTSInfo_list=saved_results)
|
858
|
+
related_preds.append(related_pred)
|
859
|
+
explanation_results.append(results)
|
860
|
+
|
861
|
+
if self.save:
|
862
|
+
torch.save(explanation_results,
|
863
|
+
os.path.join(self.save_dir, f"{self.filename}.pt"))
|
864
|
+
|
865
|
+
return None, explanation_results, related_preds
|