hjxdl-0.0.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. hjxdl-0.0.1/.github/workflows/python-publish.yml +36 -0
  2. hjxdl-0.0.1/.gitignore +70 -0
  3. hjxdl-0.0.1/.gitmodules +0 -0
  4. hjxdl-0.0.1/MANIFEST.in +2 -0
  5. hjxdl-0.0.1/PKG-INFO +19 -0
  6. hjxdl-0.0.1/README.md +6 -0
  7. hjxdl-0.0.1/__init__.py +0 -0
  8. hjxdl-0.0.1/hdl/__init__.py +0 -0
  9. hjxdl-0.0.1/hdl/_version.py +16 -0
  10. hjxdl-0.0.1/hdl/args/__init__.py +0 -0
  11. hjxdl-0.0.1/hdl/args/loss_args.py +5 -0
  12. hjxdl-0.0.1/hdl/controllers/__init__.py +0 -0
  13. hjxdl-0.0.1/hdl/controllers/al/__init__.py +0 -0
  14. hjxdl-0.0.1/hdl/controllers/al/al.py +0 -0
  15. hjxdl-0.0.1/hdl/controllers/al/dispatcher.py +0 -0
  16. hjxdl-0.0.1/hdl/controllers/al/feedback.py +0 -0
  17. hjxdl-0.0.1/hdl/controllers/explain/__init__.py +0 -0
  18. hjxdl-0.0.1/hdl/controllers/explain/shapley.py +293 -0
  19. hjxdl-0.0.1/hdl/controllers/explain/subgraphx.py +865 -0
  20. hjxdl-0.0.1/hdl/controllers/predictors/gin_predictor.py +2 -0
  21. hjxdl-0.0.1/hdl/controllers/predictors/rxn_predictor.py +113 -0
  22. hjxdl-0.0.1/hdl/controllers/predictors/torch_predictor.py +28 -0
  23. hjxdl-0.0.1/hdl/controllers/train/__init__.py +0 -0
  24. hjxdl-0.0.1/hdl/controllers/train/rxn_train.py +219 -0
  25. hjxdl-0.0.1/hdl/controllers/train/train.py +50 -0
  26. hjxdl-0.0.1/hdl/controllers/train/train_ginet.py +316 -0
  27. hjxdl-0.0.1/hdl/controllers/train/trainer_base.py +155 -0
  28. hjxdl-0.0.1/hdl/controllers/train/trainer_iterative.py +389 -0
  29. hjxdl-0.0.1/hdl/data/__init__.py +0 -0
  30. hjxdl-0.0.1/hdl/data/dataset/__init__.py +0 -0
  31. hjxdl-0.0.1/hdl/data/dataset/base_dataset.py +98 -0
  32. hjxdl-0.0.1/hdl/data/dataset/fp/__init__.py +0 -0
  33. hjxdl-0.0.1/hdl/data/dataset/fp/fp_dataset.py +122 -0
  34. hjxdl-0.0.1/hdl/data/dataset/graph/__init__.py +0 -0
  35. hjxdl-0.0.1/hdl/data/dataset/graph/chiral.py +62 -0
  36. hjxdl-0.0.1/hdl/data/dataset/graph/gin.py +255 -0
  37. hjxdl-0.0.1/hdl/data/dataset/graph/molnet.py +362 -0
  38. hjxdl-0.0.1/hdl/data/dataset/loaders/__init__.py +0 -0
  39. hjxdl-0.0.1/hdl/data/dataset/loaders/chiral_graph.py +71 -0
  40. hjxdl-0.0.1/hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
  41. hjxdl-0.0.1/hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
  42. hjxdl-0.0.1/hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
  43. hjxdl-0.0.1/hdl/data/dataset/loaders/general.py +23 -0
  44. hjxdl-0.0.1/hdl/data/dataset/loaders/spliter.py +86 -0
  45. hjxdl-0.0.1/hdl/data/dataset/samplers/__init__.py +0 -0
  46. hjxdl-0.0.1/hdl/data/dataset/samplers/chiral.py +19 -0
  47. hjxdl-0.0.1/hdl/data/dataset/seq/__init__.py +0 -0
  48. hjxdl-0.0.1/hdl/data/dataset/seq/rxn_dataset.py +61 -0
  49. hjxdl-0.0.1/hdl/data/dataset/utils.py +31 -0
  50. hjxdl-0.0.1/hdl/data/to_mols.py +0 -0
  51. hjxdl-0.0.1/hdl/features/__init__.py +0 -0
  52. hjxdl-0.0.1/hdl/features/fp/__init__.py +0 -0
  53. hjxdl-0.0.1/hdl/features/fp/features_generators.py +235 -0
  54. hjxdl-0.0.1/hdl/features/graph/__init__.py +0 -0
  55. hjxdl-0.0.1/hdl/features/graph/featurization.py +297 -0
  56. hjxdl-0.0.1/hdl/features/utils/__init__.py +0 -0
  57. hjxdl-0.0.1/hdl/features/utils/utils.py +111 -0
  58. hjxdl-0.0.1/hdl/features/vocab.txt +591 -0
  59. hjxdl-0.0.1/hdl/include/add2.h +4 -0
  60. hjxdl-0.0.1/hdl/kernel/add2_kernel.cu +18 -0
  61. hjxdl-0.0.1/hdl/kernel/test +0 -0
  62. hjxdl-0.0.1/hdl/kernel/test.cu +14 -0
  63. hjxdl-0.0.1/hdl/layers/__init__.py +0 -0
  64. hjxdl-0.0.1/hdl/layers/general/__init__.py +0 -0
  65. hjxdl-0.0.1/hdl/layers/general/gp.py +14 -0
  66. hjxdl-0.0.1/hdl/layers/general/linear.py +641 -0
  67. hjxdl-0.0.1/hdl/layers/graph/__init__.py +0 -0
  68. hjxdl-0.0.1/hdl/layers/graph/chiral_graph.py +230 -0
  69. hjxdl-0.0.1/hdl/layers/graph/gcn.py +16 -0
  70. hjxdl-0.0.1/hdl/layers/graph/gin.py +45 -0
  71. hjxdl-0.0.1/hdl/layers/graph/tetra.py +158 -0
  72. hjxdl-0.0.1/hdl/layers/graph/transformer.py +188 -0
  73. hjxdl-0.0.1/hdl/layers/sequential/__init__.py +0 -0
  74. hjxdl-0.0.1/hdl/metric_loss/__init__.py +0 -0
  75. hjxdl-0.0.1/hdl/metric_loss/loss.py +79 -0
  76. hjxdl-0.0.1/hdl/metric_loss/metric.py +178 -0
  77. hjxdl-0.0.1/hdl/metric_loss/multi_label.py +42 -0
  78. hjxdl-0.0.1/hdl/metric_loss/nt_xent.py +65 -0
  79. hjxdl-0.0.1/hdl/models/__init__.py +0 -0
  80. hjxdl-0.0.1/hdl/models/chiral_gnn.py +176 -0
  81. hjxdl-0.0.1/hdl/models/fast_transformer.py +234 -0
  82. hjxdl-0.0.1/hdl/models/ginet.py +189 -0
  83. hjxdl-0.0.1/hdl/models/linear.py +137 -0
  84. hjxdl-0.0.1/hdl/models/model_dict.py +18 -0
  85. hjxdl-0.0.1/hdl/models/norm_flows.py +33 -0
  86. hjxdl-0.0.1/hdl/models/optim_dict.py +16 -0
  87. hjxdl-0.0.1/hdl/models/rxn.py +63 -0
  88. hjxdl-0.0.1/hdl/models/utils.py +83 -0
  89. hjxdl-0.0.1/hdl/ops/__init__.py +0 -0
  90. hjxdl-0.0.1/hdl/ops/utils.py +42 -0
  91. hjxdl-0.0.1/hdl/optims/__init__.py +0 -0
  92. hjxdl-0.0.1/hdl/optims/nadam.py +86 -0
  93. hjxdl-0.0.1/hdl/pytorch/add2_ops.cpp +22 -0
  94. hjxdl-0.0.1/hdl/utils/__init__.py +0 -0
  95. hjxdl-0.0.1/hdl/utils/chemical_tools/__init__.py +2 -0
  96. hjxdl-0.0.1/hdl/utils/chemical_tools/query_info.py +149 -0
  97. hjxdl-0.0.1/hdl/utils/chemical_tools/sdf.py +20 -0
  98. hjxdl-0.0.1/hdl/utils/database_tools/__init__.py +0 -0
  99. hjxdl-0.0.1/hdl/utils/database_tools/connect.py +28 -0
  100. hjxdl-0.0.1/hdl/utils/general/__init__.py +0 -0
  101. hjxdl-0.0.1/hdl/utils/general/glob.py +21 -0
  102. hjxdl-0.0.1/hdl/utils/schedulers/__init__.py +0 -0
  103. hjxdl-0.0.1/hdl/utils/schedulers/norm_lr.py +108 -0
  104. hjxdl-0.0.1/hjxdl.egg-info/PKG-INFO +19 -0
  105. hjxdl-0.0.1/hjxdl.egg-info/SOURCES.txt +110 -0
  106. hjxdl-0.0.1/hjxdl.egg-info/dependency_links.txt +1 -0
  107. hjxdl-0.0.1/hjxdl.egg-info/top_level.txt +1 -0
  108. hjxdl-0.0.1/pyproject.toml +8 -0
  109. hjxdl-0.0.1/setup.cfg +4 -0
  110. hjxdl-0.0.1/setup.py +39 -0
  111. hjxdl-0.0.1/update_main.sh +28 -0
  112. hjxdl-0.0.1/version.txt +1 -0
hjxdl-0.0.1/.github/workflows/python-publish.yml ADDED
@@ -0,0 +1,36 @@
+ name: Build and Publish to PyPI
+
+ on:
+   push:
+     branches:
+       - main  # make sure this matches your main branch name
+
+ jobs:
+   build-and-publish:
+     runs-on: ubuntu-latest
+
+     steps:
+       - uses: actions/checkout@v2
+         with:
+           fetch-depth: 0  # make sure the full history and tags can be fetched
+
+       - name: Set up Python
+         uses: actions/setup-python@v2
+         with:
+           python-version: '3.x'
+
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install setuptools wheel
+           pip install build  # make sure the build module is installed before building
+
+       - name: Build package
+         run: |
+           python -m build
+
+       - name: Publish package to PyPI
+         uses: pypa/gh-action-pypi-publish@v1.4.2
+         with:
+           user: __token__
+           password: ${{ secrets.PYPI_API_TOKEN }}
hjxdl-0.0.1/.gitignore ADDED
@@ -0,0 +1,70 @@
+ # These are some examples of commonly ignored file patterns.
+ # You should customize this list as applicable to your project.
+ # Learn more about .gitignore:
+ # https://www.atlassian.com/git/tutorials/saving-changes/gitignore
+
+ # Node artifact files
+ node_modules/
+ dist/
+
+ # Compiled Java class files
+ *.class
+
+ # Compiled Python bytecode
+ *.py[cod]
+
+ # Log files
+ *.log
+
+ # Package files
+ *.jar
+
+ # Maven
+ target/
+ dist/
+
+ # JetBrains IDE
+ .idea/
+
+ # Unit test reports
+ TEST*.xml
+
+ # Generated by MacOS
+ .DS_Store
+
+ # Generated by Windows
+ Thumbs.db
+
+ # Applications
+ *.app
+ *.exe
+ *.war
+
+ # Large media files
+ *.mp4
+ *.tiff
+ *.avi
+ *.flv
+ *.mov
+ *.wmv
+
+ test.py
+
+ # Distribution / packaging
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ .vscode
hjxdl-0.0.1/.gitmodules ADDED
File without changes
hjxdl-0.0.1/MANIFEST.in ADDED
@@ -0,0 +1,2 @@
+ include versioneer.py
+ include hdl/_version.py
hjxdl-0.0.1/PKG-INFO ADDED
@@ -0,0 +1,19 @@
+ Metadata-Version: 2.1
+ Name: hjxdl
+ Version: 0.0.1
+ Summary: A collection of functions for Jupyter notebooks
+ Home-page: https://github.com/huluxiaohuowa/hdl
+ Author: Jianxing Hu
+ Author-email: j.hu@pku.edu.cn
+ Classifier: Programming Language :: Python :: 3
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: OS Independent
+ Requires-Python: >=3.6
+ Description-Content-Type: text/markdown
+
+ # DL framework by Jianxing
+
+ ```bash
+ git clone git@github.com:huluxiaohuowa/hdl.git
+ python setup.py install
+ ```
hjxdl-0.0.1/README.md ADDED
@@ -0,0 +1,6 @@
+ # DL framework by Jianxing
+
+ ```bash
+ git clone git@github.com:huluxiaohuowa/hdl.git
+ python setup.py install
+ ```
hjxdl-0.0.1/__init__.py ADDED
File without changes
hjxdl-0.0.1/hdl/__init__.py ADDED
File without changes
hjxdl-0.0.1/hdl/_version.py ADDED
@@ -0,0 +1,16 @@
+ # file generated by setuptools_scm
+ # don't change, don't track in version control
+ TYPE_CHECKING = False
+ if TYPE_CHECKING:
+     from typing import Tuple, Union
+     VERSION_TUPLE = Tuple[Union[int, str], ...]
+ else:
+     VERSION_TUPLE = object
+
+ version: str
+ __version__: str
+ __version_tuple__: VERSION_TUPLE
+ version_tuple: VERSION_TUPLE
+
+ __version__ = version = '0.0.1'
+ __version_tuple__ = version_tuple = (0, 0, 1)
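
Since `hdl/_version.py` is generated by setuptools_scm and pins both `__version__` and `version_tuple`, the installed version can be read directly at runtime. A minimal sketch (the `importlib.metadata` lookup assumes the distribution is installed under the name `hjxdl`):

```python
# Minimal sketch: reading the version written by setuptools_scm.
from hdl._version import __version__, version_tuple

print(__version__)    # '0.0.1'
print(version_tuple)  # (0, 0, 1)

# Equivalent lookup through installed package metadata
# (assumes the distribution is installed as "hjxdl").
from importlib.metadata import version
print(version("hjxdl"))  # '0.0.1'
```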
hjxdl-0.0.1/hdl/args/__init__.py ADDED
File without changes
hjxdl-0.0.1/hdl/args/loss_args.py ADDED
@@ -0,0 +1,5 @@
+ from tap import Tap
+
+
+ class LossArgs(Tap):
+     reduction: str = 'mean'
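
`LossArgs` subclasses `Tap` from typed-argument-parser, so the annotated `reduction` attribute doubles as a typed `--reduction` command-line flag. A minimal sketch of how such an argument class is typically consumed (the explicit argument list and the `CrossEntropyLoss` hand-off are illustrative, not part of this file):

```python
# Minimal sketch: parsing and using a Tap-based argument class.
import torch
from hdl.args.loss_args import LossArgs

args = LossArgs().parse_args(["--reduction", "sum"])  # defaults to 'mean' if omitted
print(args.reduction)  # 'sum'

# The parsed value can then be forwarded to a loss function, e.g.:
criterion = torch.nn.CrossEntropyLoss(reduction=args.reduction)
```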
hjxdl-0.0.1/hdl/controllers/__init__.py ADDED
File without changes
hjxdl-0.0.1/hdl/controllers/al/__init__.py ADDED
File without changes
hjxdl-0.0.1/hdl/controllers/al/al.py ADDED
File without changes
hjxdl-0.0.1/hdl/controllers/al/dispatcher.py ADDED
File without changes
hjxdl-0.0.1/hdl/controllers/al/feedback.py ADDED
File without changes
hjxdl-0.0.1/hdl/controllers/explain/__init__.py ADDED
File without changes
hjxdl-0.0.1/hdl/controllers/explain/shapley.py ADDED
@@ -0,0 +1,293 @@
+ import copy
+ import torch
+ import numpy as np
+ from scipy.special import comb
+ from itertools import combinations
+ import torch.nn.functional as F
+ from torch_geometric.utils import to_networkx
+ from torch_geometric.data import Data, Batch, Dataset, DataLoader
+
+
+ def GnnNetsGC2valueFunc(gnnNets, target_class):
+     def value_func(batch):
+         with torch.no_grad():
+             logits = gnnNets(data=batch)
+             probs = F.softmax(logits, dim=-1)
+             score = probs[:, target_class]
+             return score
+     return value_func
+
+
+ def GnnNetsNC2valueFunc(gnnNets_NC, node_idx, target_class):
+     def value_func(data):
+         with torch.no_grad():
+             logits = gnnNets_NC(data=data)
+             probs = F.softmax(logits, dim=-1)
+             # select the corresponding node prob through the node idx on all the sampling graphs
+             batch_size = data.batch.max() + 1
+             probs = probs.reshape(batch_size, -1, probs.shape[-1])
+             score = probs[:, node_idx, target_class]
+             return score
+     return value_func
+
+
+ def get_graph_build_func(build_method):
+     if build_method.lower() == 'zero_filling':
+         return graph_build_zero_filling
+     elif build_method.lower() == 'split':
+         return graph_build_split
+     else:
+         raise NotImplementedError
+
+
+ class MarginalSubgraphDataset(Dataset):
+     def __init__(self, data, exclude_mask, include_mask, subgraph_build_func):
+         self.num_nodes = data.num_nodes
+         self.X = data.x
+         self.edge_index = data.edge_index
+         self.device = self.X.device
+
+         self.label = data.y
+         self.exclude_mask = torch.tensor(exclude_mask).type(torch.float32).to(self.device)
+         self.include_mask = torch.tensor(include_mask).type(torch.float32).to(self.device)
+         self.subgraph_build_func = subgraph_build_func
+
+     def __len__(self):
+         return self.exclude_mask.shape[0]
+
+     def __getitem__(self, idx):
+         exclude_graph_X, exclude_graph_edge_index = self.subgraph_build_func(self.X, self.edge_index, self.exclude_mask[idx])
+         include_graph_X, include_graph_edge_index = self.subgraph_build_func(self.X, self.edge_index, self.include_mask[idx])
+         exclude_data = Data(x=exclude_graph_X, edge_index=exclude_graph_edge_index)
+         include_data = Data(x=include_graph_X, edge_index=include_graph_edge_index)
+         return exclude_data, include_data
+
+
+ def marginal_contribution(data: Data, exclude_mask: np.array, include_mask: np.array,
+                           value_func, subgraph_build_func):
+     """ Calculate the marginal value for each pair. Here exclude_mask and include_mask are node masks. """
+     marginal_subgraph_dataset = MarginalSubgraphDataset(data, exclude_mask, include_mask, subgraph_build_func)
+     dataloader = DataLoader(marginal_subgraph_dataset, batch_size=256, shuffle=False, num_workers=0)
+
+     marginal_contribution_list = []
+
+     for exclude_data, include_data in dataloader:
+         exclude_values = value_func(exclude_data)
+         include_values = value_func(include_data)
+         margin_values = include_values - exclude_values
+         marginal_contribution_list.append(margin_values)
+
+     marginal_contributions = torch.cat(marginal_contribution_list, dim=0)
+     return marginal_contributions
+
+
+ def graph_build_zero_filling(X, edge_index, node_mask: np.array):
+     """ subgraph building through masking the unselected nodes with zero features """
+     ret_X = X * node_mask.unsqueeze(1)
+     return ret_X, edge_index
+
+
+ def graph_build_split(X, edge_index, node_mask: np.array):
+     """ subgraph building through splitting the selected nodes from the original graph """
+     ret_X = X
+     row, col = edge_index
+     edge_mask = (node_mask[row] == 1) & (node_mask[col] == 1)
+     ret_edge_index = edge_index[:, edge_mask]
+     return ret_X, ret_edge_index
+
+
+ def l_shapley(coalition: list, data: Data, local_radius: int,
+               value_func: str, subgraph_building_method='zero_filling'):
+     """ Shapley value where players are local neighbor nodes """
+     graph = to_networkx(data)
+     num_nodes = graph.number_of_nodes()
+     subgraph_build_func = get_graph_build_func(subgraph_building_method)
+
+     local_region = copy.copy(coalition)
+     for k in range(local_radius - 1):
+         k_neighborhood = []
+         for node in local_region:
+             k_neighborhood += list(graph.neighbors(node))
+         local_region += k_neighborhood
+         local_region = list(set(local_region))
+
+     set_exclude_masks = []
+     set_include_masks = []
+     nodes_around = [node for node in local_region if node not in coalition]
+     num_nodes_around = len(nodes_around)
+
+     for subset_len in range(0, num_nodes_around + 1):
+         node_exclude_subsets = combinations(nodes_around, subset_len)
+         for node_exclude_subset in node_exclude_subsets:
+             set_exclude_mask = np.ones(num_nodes)
+             set_exclude_mask[local_region] = 0.0
+             if node_exclude_subset:
+                 set_exclude_mask[list(node_exclude_subset)] = 1.0
+             set_include_mask = set_exclude_mask.copy()
+             set_include_mask[coalition] = 1.0
+
+             set_exclude_masks.append(set_exclude_mask)
+             set_include_masks.append(set_include_mask)
+
+     exclude_mask = np.stack(set_exclude_masks, axis=0)
+     include_mask = np.stack(set_include_masks, axis=0)
+     num_players = len(nodes_around) + 1
+     num_player_in_set = num_players - 1 + len(coalition) - (1 - exclude_mask).sum(axis=1)
+     p = num_players
+     S = num_player_in_set
+     coeffs = torch.tensor(1.0 / comb(p, S) / (p - S + 1e-6))
+
+     marginal_contributions = \
+         marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func)
+
+     l_shapley_value = (marginal_contributions.squeeze().cpu() * coeffs).sum().item()
+     return l_shapley_value
+
+
+ def mc_shapley(coalition: list, data: Data,
+                value_func: str, subgraph_building_method='zero_filling',
+                sample_num=1000) -> float:
+     """ Monte Carlo sampling approximation of the Shapley value """
+     subset_build_func = get_graph_build_func(subgraph_building_method)
+
+     num_nodes = data.num_nodes
+     node_indices = np.arange(num_nodes)
+     coalition_placeholder = num_nodes
+     set_exclude_masks = []
+     set_include_masks = []
+
+     for example_idx in range(sample_num):
+         subset_nodes_from = [node for node in node_indices if node not in coalition]
+         random_nodes_permutation = np.array(subset_nodes_from + [coalition_placeholder])
+         random_nodes_permutation = np.random.permutation(random_nodes_permutation)
+         split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0]
+         selected_nodes = random_nodes_permutation[:split_idx]
+         set_exclude_mask = np.zeros(num_nodes)
+         set_exclude_mask[selected_nodes] = 1.0
+         set_include_mask = set_exclude_mask.copy()
+         set_include_mask[coalition] = 1.0
+
+         set_exclude_masks.append(set_exclude_mask)
+         set_include_masks.append(set_include_mask)
+
+     exclude_mask = np.stack(set_exclude_masks, axis=0)
+     include_mask = np.stack(set_include_masks, axis=0)
+     marginal_contributions = marginal_contribution(data, exclude_mask, include_mask, value_func, subset_build_func)
+     mc_shapley_value = marginal_contributions.mean().item()
+
+     return mc_shapley_value
+
+
+ def mc_l_shapley(coalition: list, data: Data, local_radius: int,
+                  value_func: str, subgraph_building_method='zero_filling',
+                  sample_num=1000) -> float:
+     """ Monte Carlo sampling approximation of the l_shapley value """
+     graph = to_networkx(data)
+     num_nodes = graph.number_of_nodes()
+     subgraph_build_func = get_graph_build_func(subgraph_building_method)
+
+     local_region = copy.copy(coalition)
+     for k in range(local_radius - 1):
+         k_neighborhood = []
+         for node in local_region:
+             k_neighborhood += list(graph.neighbors(node))
+         local_region += k_neighborhood
+         local_region = list(set(local_region))
+
+     coalition_placeholder = num_nodes
+     set_exclude_masks = []
+     set_include_masks = []
+     for example_idx in range(sample_num):
+         subset_nodes_from = [node for node in local_region if node not in coalition]
+         random_nodes_permutation = np.array(subset_nodes_from + [coalition_placeholder])
+         random_nodes_permutation = np.random.permutation(random_nodes_permutation)
+         split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0]
+         selected_nodes = random_nodes_permutation[:split_idx]
+         set_exclude_mask = np.ones(num_nodes)
+         set_exclude_mask[local_region] = 0.0
+         set_exclude_mask[selected_nodes] = 1.0
+         set_include_mask = set_exclude_mask.copy()
+         set_include_mask[coalition] = 1.0
+
+         set_exclude_masks.append(set_exclude_mask)
+         set_include_masks.append(set_include_mask)
+
+     exclude_mask = np.stack(set_exclude_masks, axis=0)
+     include_mask = np.stack(set_include_masks, axis=0)
+     marginal_contributions = \
+         marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func)
+
+     mc_l_shapley_value = (marginal_contributions).mean().item()
+     return mc_l_shapley_value
+
+
+ def gnn_score(coalition: list, data: Data, value_func: str,
+               subgraph_building_method='zero_filling') -> torch.Tensor:
+     """ the value of the subgraph with selected nodes """
+     num_nodes = data.num_nodes
+     subgraph_build_func = get_graph_build_func(subgraph_building_method)
+     mask = torch.zeros(num_nodes).type(torch.float32).to(data.x.device)
+     mask[coalition] = 1.0
+     ret_x, ret_edge_index = subgraph_build_func(data.x, data.edge_index, mask)
+     mask_data = Data(x=ret_x, edge_index=ret_edge_index)
+     mask_data = Batch.from_data_list([mask_data])
+     score = value_func(mask_data)
+     # get the score of the predicted class for the graph or a specific node idx
+     return score.item()
+
+
+ def NC_mc_l_shapley(coalition: list, data: Data, local_radius: int,
+                     value_func: str, node_idx: int = -1,
+                     subgraph_building_method='zero_filling', sample_num=1000) -> float:
+     """ Monte Carlo approximation of l_shapley where the target node is kept in both subgraphs """
+     graph = to_networkx(data)
+     num_nodes = graph.number_of_nodes()
+     subgraph_build_func = get_graph_build_func(subgraph_building_method)
+
+     local_region = copy.copy(coalition)
+     for k in range(local_radius - 1):
+         k_neighborhood = []
+         for node in local_region:
+             k_neighborhood += list(graph.neighbors(node))
+         local_region += k_neighborhood
+         local_region = list(set(local_region))
+
+     coalition_placeholder = num_nodes
+     set_exclude_masks = []
+     set_include_masks = []
+     for example_idx in range(sample_num):
+         subset_nodes_from = [node for node in local_region if node not in coalition]
+         random_nodes_permutation = np.array(subset_nodes_from + [coalition_placeholder])
+         random_nodes_permutation = np.random.permutation(random_nodes_permutation)
+         split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0]
+         selected_nodes = random_nodes_permutation[:split_idx]
+         set_exclude_mask = np.ones(num_nodes)
+         set_exclude_mask[local_region] = 0.0
+         set_exclude_mask[selected_nodes] = 1.0
+         if node_idx != -1:
+             set_exclude_mask[node_idx] = 1.0
+         set_include_mask = set_exclude_mask.copy()
+         set_include_mask[coalition] = 1.0  # include the node_idx
+
+         set_exclude_masks.append(set_exclude_mask)
+         set_include_masks.append(set_include_mask)
+
+     exclude_mask = np.stack(set_exclude_masks, axis=0)
+     include_mask = np.stack(set_include_masks, axis=0)
+     marginal_contributions = \
+         marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func)
+
+     mc_l_shapley_value = (marginal_contributions).mean().item()
+     return mc_l_shapley_value
+
+
+ def sparsity(coalition: list, data: Data, subgraph_building_method='zero_filling'):
+     if subgraph_building_method == 'zero_filling':
+         return 1.0 - len(coalition) / data.num_nodes
+
+     elif subgraph_building_method == 'split':
+         row, col = data.edge_index
+         node_mask = torch.zeros(data.x.shape[0])
+         node_mask[coalition] = 1.0
+         edge_mask = (node_mask[row] == 1) & (node_mask[col] == 1)
+         return 1.0 - edge_mask.sum() / edge_mask.shape[0]
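
The functions in `shapley.py` score node coalitions for GNN explanation via Shapley-value approximations: a value function wraps a trained graph classifier, and the exclude/include masks built above measure each coalition's marginal contribution to the target-class probability. A minimal usage sketch follows; the toy graph and `ToyClassifier` are hypothetical stand-ins, and the sketch assumes `hjxdl` and `torch_geometric` are installed:

```python
# Hypothetical usage sketch: scoring a node coalition with the helpers above.
import torch
from torch_geometric.data import Data
from torch_geometric.nn import global_mean_pool

from hdl.controllers.explain.shapley import GnnNetsGC2valueFunc, gnn_score, sparsity


class ToyClassifier(torch.nn.Module):
    """Stand-in for a trained GNN: any model callable as model(data=batch) returning logits works."""
    def __init__(self, in_dim=8, num_classes=2):
        super().__init__()
        self.lin = torch.nn.Linear(in_dim, num_classes)

    def forward(self, data):
        h = self.lin(data.x)                    # per-node logits
        return global_mean_pool(h, data.batch)  # mean-pool to graph-level logits


# A toy 4-node graph; in practice `data` would come from one of the package's datasets.
data = Data(
    x=torch.randn(4, 8),
    edge_index=torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]),
)

model = ToyClassifier().eval()
value_func = GnnNetsGC2valueFunc(model, target_class=1)

coalition = [0, 1]  # candidate explanatory subgraph (node indices)
print(gnn_score(coalition, data, value_func))  # class-1 probability of the masked graph
print(sparsity(coalition, data))               # fraction of nodes left out: 0.5
# mc_shapley, l_shapley and mc_l_shapley follow the same calling pattern,
# with extra arguments for the sample count and the local radius.
```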