hjxdl-0.0.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hjxdl-0.0.1/.github/workflows/python-publish.yml +36 -0
- hjxdl-0.0.1/.gitignore +70 -0
- hjxdl-0.0.1/.gitmodules +0 -0
- hjxdl-0.0.1/MANIFEST.in +2 -0
- hjxdl-0.0.1/PKG-INFO +19 -0
- hjxdl-0.0.1/README.md +6 -0
- hjxdl-0.0.1/__init__.py +0 -0
- hjxdl-0.0.1/hdl/__init__.py +0 -0
- hjxdl-0.0.1/hdl/_version.py +16 -0
- hjxdl-0.0.1/hdl/args/__init__.py +0 -0
- hjxdl-0.0.1/hdl/args/loss_args.py +5 -0
- hjxdl-0.0.1/hdl/controllers/__init__.py +0 -0
- hjxdl-0.0.1/hdl/controllers/al/__init__.py +0 -0
- hjxdl-0.0.1/hdl/controllers/al/al.py +0 -0
- hjxdl-0.0.1/hdl/controllers/al/dispatcher.py +0 -0
- hjxdl-0.0.1/hdl/controllers/al/feedback.py +0 -0
- hjxdl-0.0.1/hdl/controllers/explain/__init__.py +0 -0
- hjxdl-0.0.1/hdl/controllers/explain/shapley.py +293 -0
- hjxdl-0.0.1/hdl/controllers/explain/subgraphx.py +865 -0
- hjxdl-0.0.1/hdl/controllers/predictors/gin_predictor.py +2 -0
- hjxdl-0.0.1/hdl/controllers/predictors/rxn_predictor.py +113 -0
- hjxdl-0.0.1/hdl/controllers/predictors/torch_predictor.py +28 -0
- hjxdl-0.0.1/hdl/controllers/train/__init__.py +0 -0
- hjxdl-0.0.1/hdl/controllers/train/rxn_train.py +219 -0
- hjxdl-0.0.1/hdl/controllers/train/train.py +50 -0
- hjxdl-0.0.1/hdl/controllers/train/train_ginet.py +316 -0
- hjxdl-0.0.1/hdl/controllers/train/trainer_base.py +155 -0
- hjxdl-0.0.1/hdl/controllers/train/trainer_iterative.py +389 -0
- hjxdl-0.0.1/hdl/data/__init__.py +0 -0
- hjxdl-0.0.1/hdl/data/dataset/__init__.py +0 -0
- hjxdl-0.0.1/hdl/data/dataset/base_dataset.py +98 -0
- hjxdl-0.0.1/hdl/data/dataset/fp/__init__.py +0 -0
- hjxdl-0.0.1/hdl/data/dataset/fp/fp_dataset.py +122 -0
- hjxdl-0.0.1/hdl/data/dataset/graph/__init__.py +0 -0
- hjxdl-0.0.1/hdl/data/dataset/graph/chiral.py +62 -0
- hjxdl-0.0.1/hdl/data/dataset/graph/gin.py +255 -0
- hjxdl-0.0.1/hdl/data/dataset/graph/molnet.py +362 -0
- hjxdl-0.0.1/hdl/data/dataset/loaders/__init__.py +0 -0
- hjxdl-0.0.1/hdl/data/dataset/loaders/chiral_graph.py +71 -0
- hjxdl-0.0.1/hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
- hjxdl-0.0.1/hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
- hjxdl-0.0.1/hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
- hjxdl-0.0.1/hdl/data/dataset/loaders/general.py +23 -0
- hjxdl-0.0.1/hdl/data/dataset/loaders/spliter.py +86 -0
- hjxdl-0.0.1/hdl/data/dataset/samplers/__init__.py +0 -0
- hjxdl-0.0.1/hdl/data/dataset/samplers/chiral.py +19 -0
- hjxdl-0.0.1/hdl/data/dataset/seq/__init__.py +0 -0
- hjxdl-0.0.1/hdl/data/dataset/seq/rxn_dataset.py +61 -0
- hjxdl-0.0.1/hdl/data/dataset/utils.py +31 -0
- hjxdl-0.0.1/hdl/data/to_mols.py +0 -0
- hjxdl-0.0.1/hdl/features/__init__.py +0 -0
- hjxdl-0.0.1/hdl/features/fp/__init__.py +0 -0
- hjxdl-0.0.1/hdl/features/fp/features_generators.py +235 -0
- hjxdl-0.0.1/hdl/features/graph/__init__.py +0 -0
- hjxdl-0.0.1/hdl/features/graph/featurization.py +297 -0
- hjxdl-0.0.1/hdl/features/utils/__init__.py +0 -0
- hjxdl-0.0.1/hdl/features/utils/utils.py +111 -0
- hjxdl-0.0.1/hdl/features/vocab.txt +591 -0
- hjxdl-0.0.1/hdl/include/add2.h +4 -0
- hjxdl-0.0.1/hdl/kernel/add2_kernel.cu +18 -0
- hjxdl-0.0.1/hdl/kernel/test +0 -0
- hjxdl-0.0.1/hdl/kernel/test.cu +14 -0
- hjxdl-0.0.1/hdl/layers/__init__.py +0 -0
- hjxdl-0.0.1/hdl/layers/general/__init__.py +0 -0
- hjxdl-0.0.1/hdl/layers/general/gp.py +14 -0
- hjxdl-0.0.1/hdl/layers/general/linear.py +641 -0
- hjxdl-0.0.1/hdl/layers/graph/__init__.py +0 -0
- hjxdl-0.0.1/hdl/layers/graph/chiral_graph.py +230 -0
- hjxdl-0.0.1/hdl/layers/graph/gcn.py +16 -0
- hjxdl-0.0.1/hdl/layers/graph/gin.py +45 -0
- hjxdl-0.0.1/hdl/layers/graph/tetra.py +158 -0
- hjxdl-0.0.1/hdl/layers/graph/transformer.py +188 -0
- hjxdl-0.0.1/hdl/layers/sequential/__init__.py +0 -0
- hjxdl-0.0.1/hdl/metric_loss/__init__.py +0 -0
- hjxdl-0.0.1/hdl/metric_loss/loss.py +79 -0
- hjxdl-0.0.1/hdl/metric_loss/metric.py +178 -0
- hjxdl-0.0.1/hdl/metric_loss/multi_label.py +42 -0
- hjxdl-0.0.1/hdl/metric_loss/nt_xent.py +65 -0
- hjxdl-0.0.1/hdl/models/__init__.py +0 -0
- hjxdl-0.0.1/hdl/models/chiral_gnn.py +176 -0
- hjxdl-0.0.1/hdl/models/fast_transformer.py +234 -0
- hjxdl-0.0.1/hdl/models/ginet.py +189 -0
- hjxdl-0.0.1/hdl/models/linear.py +137 -0
- hjxdl-0.0.1/hdl/models/model_dict.py +18 -0
- hjxdl-0.0.1/hdl/models/norm_flows.py +33 -0
- hjxdl-0.0.1/hdl/models/optim_dict.py +16 -0
- hjxdl-0.0.1/hdl/models/rxn.py +63 -0
- hjxdl-0.0.1/hdl/models/utils.py +83 -0
- hjxdl-0.0.1/hdl/ops/__init__.py +0 -0
- hjxdl-0.0.1/hdl/ops/utils.py +42 -0
- hjxdl-0.0.1/hdl/optims/__init__.py +0 -0
- hjxdl-0.0.1/hdl/optims/nadam.py +86 -0
- hjxdl-0.0.1/hdl/pytorch/add2_ops.cpp +22 -0
- hjxdl-0.0.1/hdl/utils/__init__.py +0 -0
- hjxdl-0.0.1/hdl/utils/chemical_tools/__init__.py +2 -0
- hjxdl-0.0.1/hdl/utils/chemical_tools/query_info.py +149 -0
- hjxdl-0.0.1/hdl/utils/chemical_tools/sdf.py +20 -0
- hjxdl-0.0.1/hdl/utils/database_tools/__init__.py +0 -0
- hjxdl-0.0.1/hdl/utils/database_tools/connect.py +28 -0
- hjxdl-0.0.1/hdl/utils/general/__init__.py +0 -0
- hjxdl-0.0.1/hdl/utils/general/glob.py +21 -0
- hjxdl-0.0.1/hdl/utils/schedulers/__init__.py +0 -0
- hjxdl-0.0.1/hdl/utils/schedulers/norm_lr.py +108 -0
- hjxdl-0.0.1/hjxdl.egg-info/PKG-INFO +19 -0
- hjxdl-0.0.1/hjxdl.egg-info/SOURCES.txt +110 -0
- hjxdl-0.0.1/hjxdl.egg-info/dependency_links.txt +1 -0
- hjxdl-0.0.1/hjxdl.egg-info/top_level.txt +1 -0
- hjxdl-0.0.1/pyproject.toml +8 -0
- hjxdl-0.0.1/setup.cfg +4 -0
- hjxdl-0.0.1/setup.py +39 -0
- hjxdl-0.0.1/update_main.sh +28 -0
- hjxdl-0.0.1/version.txt +1 -0
hjxdl-0.0.1/.github/workflows/python-publish.yml
ADDED
@@ -0,0 +1,36 @@
+name: Build and Publish to PyPI
+
+on:
+  push:
+    branches:
+      - main  # make sure this matches your main branch name
+
+jobs:
+  build-and-publish:
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v2
+      with:
+        fetch-depth: 0  # make sure the full history and tags can be fetched
+
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: '3.x'
+
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install setuptools wheel
+        pip install build  # make sure the build module is installed before building
+
+    - name: Build package
+      run: |
+        python -m build
+
+    - name: Publish package to PyPI
+      uses: pypa/gh-action-pypi-publish@v1.4.2
+      with:
+        user: __token__
+        password: ${{ secrets.PYPI_API_TOKEN }}
hjxdl-0.0.1/.gitignore
ADDED
@@ -0,0 +1,70 @@
+# These are some examples of commonly ignored file patterns.
+# You should customize this list as applicable to your project.
+# Learn more about .gitignore:
+# https://www.atlassian.com/git/tutorials/saving-changes/gitignore
+
+# Node artifact files
+node_modules/
+dist/
+
+# Compiled Java class files
+*.class
+
+# Compiled Python bytecode
+*.py[cod]
+
+# Log files
+*.log
+
+# Package files
+*.jar
+
+# Maven
+target/
+dist/
+
+# JetBrains IDE
+.idea/
+
+# Unit test reports
+TEST*.xml
+
+# Generated by MacOS
+.DS_Store
+
+# Generated by Windows
+Thumbs.db
+
+# Applications
+*.app
+*.exe
+*.war
+
+# Large media files
+*.mp4
+*.tiff
+*.avi
+*.flv
+*.mov
+*.wmv
+
+test.py
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg
+.vscode
hjxdl-0.0.1/.gitmodules
ADDED
File without changes
hjxdl-0.0.1/MANIFEST.in
ADDED
hjxdl-0.0.1/PKG-INFO
ADDED
@@ -0,0 +1,19 @@
+Metadata-Version: 2.1
+Name: hjxdl
+Version: 0.0.1
+Summary: A collection of functions for Jupyter notebooks
+Home-page: https://github.com/huluxiaohuowa/hdl
+Author: Jianxing Hu
+Author-email: j.hu@pku.edu.cn
+Classifier: Programming Language :: Python :: 3
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Operating System :: OS Independent
+Requires-Python: >=3.6
+Description-Content-Type: text/markdown
+
+# DL framework by Jianxing
+
+```bash
+git clone git@github.com:huluxiaohuowa/hdl.git
+python setup.py install
+```
hjxdl-0.0.1/README.md
ADDED
hjxdl-0.0.1/__init__.py
ADDED
File without changes
hjxdl-0.0.1/hdl/__init__.py
ADDED
File without changes
hjxdl-0.0.1/hdl/_version.py
ADDED
@@ -0,0 +1,16 @@
+# file generated by setuptools_scm
+# don't change, don't track in version control
+TYPE_CHECKING = False
+if TYPE_CHECKING:
+    from typing import Tuple, Union
+    VERSION_TUPLE = Tuple[Union[int, str], ...]
+else:
+    VERSION_TUPLE = object
+
+version: str
+__version__: str
+__version_tuple__: VERSION_TUPLE
+version_tuple: VERSION_TUPLE
+
+__version__ = version = '0.0.1'
+__version_tuple__ = version_tuple = (0, 0, 1)
hjxdl-0.0.1/hdl/args/__init__.py
ADDED
File without changes
hjxdl-0.0.1/hdl/controllers/__init__.py
ADDED
File without changes
hjxdl-0.0.1/hdl/controllers/al/__init__.py
ADDED
File without changes
hjxdl-0.0.1/hdl/controllers/al/al.py
ADDED
File without changes
hjxdl-0.0.1/hdl/controllers/al/dispatcher.py
ADDED
File without changes
hjxdl-0.0.1/hdl/controllers/al/feedback.py
ADDED
File without changes
hjxdl-0.0.1/hdl/controllers/explain/__init__.py
ADDED
File without changes
hjxdl-0.0.1/hdl/controllers/explain/shapley.py
ADDED
@@ -0,0 +1,293 @@
+import copy
+import torch
+import numpy as np
+from scipy.special import comb
+from itertools import combinations
+import torch.nn.functional as F
+from torch_geometric.utils import to_networkx
+from torch_geometric.data import Data, Batch, Dataset, DataLoader
+
+
+def GnnNetsGC2valueFunc(gnnNets, target_class):
+    def value_func(batch):
+        with torch.no_grad():
+            logits = gnnNets(data=batch)
+            probs = F.softmax(logits, dim=-1)
+            score = probs[:, target_class]
+        return score
+    return value_func
+
+
+def GnnNetsNC2valueFunc(gnnNets_NC, node_idx, target_class):
+    def value_func(data):
+        with torch.no_grad():
+            logits = gnnNets_NC(data=data)
+            probs = F.softmax(logits, dim=-1)
+            # select the corresponding node prob through the node idx on all the sampling graphs
+            batch_size = data.batch.max() + 1
+            probs = probs.reshape(batch_size, -1, probs.shape[-1])
+            score = probs[:, node_idx, target_class]
+            return score
+    return value_func
+
+
+def get_graph_build_func(build_method):
+    if build_method.lower() == 'zero_filling':
+        return graph_build_zero_filling
+    elif build_method.lower() == 'split':
+        return graph_build_split
+    else:
+        raise NotImplementedError
+
+
+class MarginalSubgraphDataset(Dataset):
+    def __init__(self, data, exclude_mask, include_mask, subgraph_build_func):
+        self.num_nodes = data.num_nodes
+        self.X = data.x
+        self.edge_index = data.edge_index
+        self.device = self.X.device
+
+        self.label = data.y
+        self.exclude_mask = torch.tensor(exclude_mask).type(torch.float32).to(self.device)
+        self.include_mask = torch.tensor(include_mask).type(torch.float32).to(self.device)
+        self.subgraph_build_func = subgraph_build_func
+
+    def __len__(self):
+        return self.exclude_mask.shape[0]
+
+    def __getitem__(self, idx):
+        exclude_graph_X, exclude_graph_edge_index = self.subgraph_build_func(self.X, self.edge_index, self.exclude_mask[idx])
+        include_graph_X, include_graph_edge_index = self.subgraph_build_func(self.X, self.edge_index, self.include_mask[idx])
+        exclude_data = Data(x=exclude_graph_X, edge_index=exclude_graph_edge_index)
+        include_data = Data(x=include_graph_X, edge_index=include_graph_edge_index)
+        return exclude_data, include_data
+
+
+def marginal_contribution(data: Data, exclude_mask: np.array, include_mask: np.array,
+                          value_func, subgraph_build_func):
+    """ Calculate the marginal value for each pair. Here exclude_mask and include_mask are node mask. """
+    marginal_subgraph_dataset = MarginalSubgraphDataset(data, exclude_mask, include_mask, subgraph_build_func)
+    dataloader = DataLoader(marginal_subgraph_dataset, batch_size=256, shuffle=False, num_workers=0)
+
+    marginal_contribution_list = []
+
+    for exclude_data, include_data in dataloader:
+        exclude_values = value_func(exclude_data)
+        include_values = value_func(include_data)
+        margin_values = include_values - exclude_values
+        marginal_contribution_list.append(margin_values)
+
+    marginal_contributions = torch.cat(marginal_contribution_list, dim=0)
+    return marginal_contributions
+
+
+def graph_build_zero_filling(X, edge_index, node_mask: np.array):
+    """ subgraph building through masking the unselected nodes with zero features """
+    ret_X = X * node_mask.unsqueeze(1)
+    return ret_X, edge_index
+
+
+def graph_build_split(X, edge_index, node_mask: np.array):
+    """ subgraph building through spliting the selected nodes from the original graph """
+    ret_X = X
+    row, col = edge_index
+    edge_mask = (node_mask[row] == 1) & (node_mask[col] == 1)
+    ret_edge_index = edge_index[:, edge_mask]
+    return ret_X, ret_edge_index
+
+
+def l_shapley(coalition: list, data: Data, local_radius: int,
+              value_func: str, subgraph_building_method='zero_filling'):
+    """ shapley value where players are local neighbor nodes """
+    graph = to_networkx(data)
+    num_nodes = graph.number_of_nodes()
+    subgraph_build_func = get_graph_build_func(subgraph_building_method)
+
+    local_region = copy.copy(coalition)
+    for k in range(local_radius - 1):
+        k_neiborhoood = []
+        for node in local_region:
+            k_neiborhoood += list(graph.neighbors(node))
+        local_region += k_neiborhoood
+        local_region = list(set(local_region))
+
+    set_exclude_masks = []
+    set_include_masks = []
+    nodes_around = [node for node in local_region if node not in coalition]
+    num_nodes_around = len(nodes_around)
+
+    for subset_len in range(0, num_nodes_around + 1):
+        node_exclude_subsets = combinations(nodes_around, subset_len)
+        for node_exclude_subset in node_exclude_subsets:
+            set_exclude_mask = np.ones(num_nodes)
+            set_exclude_mask[local_region] = 0.0
+            if node_exclude_subset:
+                set_exclude_mask[list(node_exclude_subset)] = 1.0
+            set_include_mask = set_exclude_mask.copy()
+            set_include_mask[coalition] = 1.0
+
+            set_exclude_masks.append(set_exclude_mask)
+            set_include_masks.append(set_include_mask)
+
+    exclude_mask = np.stack(set_exclude_masks, axis=0)
+    include_mask = np.stack(set_include_masks, axis=0)
+    num_players = len(nodes_around) + 1
+    num_player_in_set = num_players - 1 + len(coalition) - (1 - exclude_mask).sum(axis=1)
+    p = num_players
+    S = num_player_in_set
+    coeffs = torch.tensor(1.0 / comb(p, S) / (p - S + 1e-6))
+
+    marginal_contributions = \
+        marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func)
+
+    l_shapley_value = (marginal_contributions.squeeze().cpu() * coeffs).sum().item()
+    return l_shapley_value
+
+
+def mc_shapley(coalition: list, data: Data,
+               value_func: str, subgraph_building_method='zero_filling',
+               sample_num=1000) -> float:
+    """ monte carlo sampling approximation of the shapley value """
+    subset_build_func = get_graph_build_func(subgraph_building_method)
+
+    num_nodes = data.num_nodes
+    node_indices = np.arange(num_nodes)
+    coalition_placeholder = num_nodes
+    set_exclude_masks = []
+    set_include_masks = []
+
+    for example_idx in range(sample_num):
+        subset_nodes_from = [node for node in node_indices if node not in coalition]
+        random_nodes_permutation = np.array(subset_nodes_from + [coalition_placeholder])
+        random_nodes_permutation = np.random.permutation(random_nodes_permutation)
+        split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0]
+        selected_nodes = random_nodes_permutation[:split_idx]
+        set_exclude_mask = np.zeros(num_nodes)
+        set_exclude_mask[selected_nodes] = 1.0
+        set_include_mask = set_exclude_mask.copy()
+        set_include_mask[coalition] = 1.0
+
+        set_exclude_masks.append(set_exclude_mask)
+        set_include_masks.append(set_include_mask)
+
+    exclude_mask = np.stack(set_exclude_masks, axis=0)
+    include_mask = np.stack(set_include_masks, axis=0)
+    marginal_contributions = marginal_contribution(data, exclude_mask, include_mask, value_func, subset_build_func)
+    mc_shapley_value = marginal_contributions.mean().item()
+
+    return mc_shapley_value
+
+
+def mc_l_shapley(coalition: list, data: Data, local_radius: int,
+                 value_func: str, subgraph_building_method='zero_filling',
+                 sample_num=1000) -> float:
+    """ monte carlo sampling approximation of the l_shapley value """
+    graph = to_networkx(data)
+    num_nodes = graph.number_of_nodes()
+    subgraph_build_func = get_graph_build_func(subgraph_building_method)
+
+    local_region = copy.copy(coalition)
+    for k in range(local_radius - 1):
+        k_neiborhoood = []
+        for node in local_region:
+            k_neiborhoood += list(graph.neighbors(node))
+        local_region += k_neiborhoood
+        local_region = list(set(local_region))
+
+    coalition_placeholder = num_nodes
+    set_exclude_masks = []
+    set_include_masks = []
+    for example_idx in range(sample_num):
+        subset_nodes_from = [node for node in local_region if node not in coalition]
+        random_nodes_permutation = np.array(subset_nodes_from + [coalition_placeholder])
+        random_nodes_permutation = np.random.permutation(random_nodes_permutation)
+        split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0]
+        selected_nodes = random_nodes_permutation[:split_idx]
+        set_exclude_mask = np.ones(num_nodes)
+        set_exclude_mask[local_region] = 0.0
+        set_exclude_mask[selected_nodes] = 1.0
+        set_include_mask = set_exclude_mask.copy()
+        set_include_mask[coalition] = 1.0
+
+        set_exclude_masks.append(set_exclude_mask)
+        set_include_masks.append(set_include_mask)
+
+    exclude_mask = np.stack(set_exclude_masks, axis=0)
+    include_mask = np.stack(set_include_masks, axis=0)
+    marginal_contributions = \
+        marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func)
+
+    mc_l_shapley_value = (marginal_contributions).mean().item()
+    return mc_l_shapley_value
+
+
+def gnn_score(coalition: list, data: Data, value_func: str,
+              subgraph_building_method='zero_filling') -> torch.Tensor:
+    """ the value of subgraph with selected nodes """
+    num_nodes = data.num_nodes
+    subgraph_build_func = get_graph_build_func(subgraph_building_method)
+    mask = torch.zeros(num_nodes).type(torch.float32).to(data.x.device)
+    mask[coalition] = 1.0
+    ret_x, ret_edge_index = subgraph_build_func(data.x, data.edge_index, mask)
+    mask_data = Data(x=ret_x, edge_index=ret_edge_index)
+    mask_data = Batch.from_data_list([mask_data])
+    score = value_func(mask_data)
+    # get the score of predicted class for graph or specific node idx
+    return score.item()
+
+
+def NC_mc_l_shapley(coalition: list, data: Data, local_radius: int,
+                    value_func: str, node_idx: int = -1,
+                    subgraph_building_method='zero_filling', sample_num=1000) -> float:
+    """ monte carlo approximation of l_shapley where the target node is kept in both subgraph """
+    graph = to_networkx(data)
+    num_nodes = graph.number_of_nodes()
+    subgraph_build_func = get_graph_build_func(subgraph_building_method)
+
+    local_region = copy.copy(coalition)
+    for k in range(local_radius - 1):
+        k_neiborhoood = []
+        for node in local_region:
+            k_neiborhoood += list(graph.neighbors(node))
+        local_region += k_neiborhoood
+        local_region = list(set(local_region))
+
+    coalition_placeholder = num_nodes
+    set_exclude_masks = []
+    set_include_masks = []
+    for example_idx in range(sample_num):
+        subset_nodes_from = [node for node in local_region if node not in coalition]
+        random_nodes_permutation = np.array(subset_nodes_from + [coalition_placeholder])
+        random_nodes_permutation = np.random.permutation(random_nodes_permutation)
+        split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0]
+        selected_nodes = random_nodes_permutation[:split_idx]
+        set_exclude_mask = np.ones(num_nodes)
+        set_exclude_mask[local_region] = 0.0
+        set_exclude_mask[selected_nodes] = 1.0
+        if node_idx != -1:
+            set_exclude_mask[node_idx] = 1.0
+        set_include_mask = set_exclude_mask.copy()
+        set_include_mask[coalition] = 1.0  # include the node_idx
+
+        set_exclude_masks.append(set_exclude_mask)
+        set_include_masks.append(set_include_mask)
+
+    exclude_mask = np.stack(set_exclude_masks, axis=0)
+    include_mask = np.stack(set_include_masks, axis=0)
+    marginal_contributions = \
+        marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func)
+
+    mc_l_shapley_value = (marginal_contributions).mean().item()
+    return mc_l_shapley_value
+
+
+def sparsity(coalition: list, data: Data, subgraph_building_method='zero_filling'):
+    if subgraph_building_method == 'zero_filling':
+        return 1.0 - len(coalition) / data.num_nodes
+
+    elif subgraph_building_method == 'split':
+        row, col = data.edge_index
+        node_mask = torch.zeros(data.x.shape[0])
+        node_mask[coalition] = 1.0
+        edge_mask = (node_mask[row] == 1) & (node_mask[col] == 1)
+        return 1.0 - edge_mask.sum() / edge_mask.shape[0]