pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic. Click here for more details.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
|
@@ -1,21 +1,26 @@
|
|
|
1
1
|
import logging
|
|
2
|
-
from typing import Optional, Union
|
|
2
|
+
from typing import Dict, Optional, Tuple, Union, overload
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
5
|
from torch import Tensor
|
|
6
6
|
from torch.nn import ReLU, Sequential
|
|
7
7
|
|
|
8
|
-
from torch_geometric.explain import Explanation
|
|
8
|
+
from torch_geometric.explain import Explanation, HeteroExplanation
|
|
9
9
|
from torch_geometric.explain.algorithm import ExplainerAlgorithm
|
|
10
|
-
from torch_geometric.explain.algorithm.utils import
|
|
10
|
+
from torch_geometric.explain.algorithm.utils import (
|
|
11
|
+
clear_masks,
|
|
12
|
+
set_hetero_masks,
|
|
13
|
+
set_masks,
|
|
14
|
+
)
|
|
11
15
|
from torch_geometric.explain.config import (
|
|
12
16
|
ExplanationType,
|
|
13
17
|
ModelMode,
|
|
14
18
|
ModelTaskLevel,
|
|
15
19
|
)
|
|
16
|
-
from torch_geometric.nn import Linear
|
|
20
|
+
from torch_geometric.nn import HANConv, HeteroConv, HGTConv, Linear
|
|
17
21
|
from torch_geometric.nn.inits import reset
|
|
18
|
-
from torch_geometric.
|
|
22
|
+
from torch_geometric.typing import EdgeType, NodeType
|
|
23
|
+
from torch_geometric.utils import get_embeddings, get_embeddings_hetero
|
|
19
24
|
|
|
20
25
|
|
|
21
26
|
class PGExplainer(ExplainerAlgorithm):
|
|
@@ -62,6 +67,13 @@ class PGExplainer(ExplainerAlgorithm):
|
|
|
62
67
|
'bias': 0.01,
|
|
63
68
|
}
|
|
64
69
|
|
|
70
|
+
# NOTE: Add more in the future as needed.
|
|
71
|
+
SUPPORTED_HETERO_MODELS = [
|
|
72
|
+
HGTConv,
|
|
73
|
+
HANConv,
|
|
74
|
+
HeteroConv,
|
|
75
|
+
]
|
|
76
|
+
|
|
65
77
|
def __init__(self, epochs: int, lr: float = 0.003, **kwargs):
|
|
66
78
|
super().__init__()
|
|
67
79
|
self.epochs = epochs
|
|
@@ -75,11 +87,13 @@ class PGExplainer(ExplainerAlgorithm):
|
|
|
75
87
|
)
|
|
76
88
|
self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=lr)
|
|
77
89
|
self._curr_epoch = -1
|
|
90
|
+
self.is_hetero = False
|
|
78
91
|
|
|
79
92
|
def reset_parameters(self):
|
|
80
93
|
r"""Resets all learnable parameters of the module."""
|
|
81
94
|
reset(self.mlp)
|
|
82
95
|
|
|
96
|
+
@overload
|
|
83
97
|
def train(
|
|
84
98
|
self,
|
|
85
99
|
epoch: int,
|
|
@@ -90,17 +104,44 @@ class PGExplainer(ExplainerAlgorithm):
|
|
|
90
104
|
target: Tensor,
|
|
91
105
|
index: Optional[Union[int, Tensor]] = None,
|
|
92
106
|
**kwargs,
|
|
93
|
-
):
|
|
107
|
+
) -> float:
|
|
108
|
+
...
|
|
109
|
+
|
|
110
|
+
@overload
|
|
111
|
+
def train(
|
|
112
|
+
self,
|
|
113
|
+
epoch: int,
|
|
114
|
+
model: torch.nn.Module,
|
|
115
|
+
x: Dict[NodeType, Tensor],
|
|
116
|
+
edge_index: Dict[EdgeType, Tensor],
|
|
117
|
+
*,
|
|
118
|
+
target: Tensor,
|
|
119
|
+
index: Optional[Union[int, Tensor]] = None,
|
|
120
|
+
**kwargs,
|
|
121
|
+
) -> float:
|
|
122
|
+
...
|
|
123
|
+
|
|
124
|
+
def train(
|
|
125
|
+
self,
|
|
126
|
+
epoch: int,
|
|
127
|
+
model: torch.nn.Module,
|
|
128
|
+
x: Union[Tensor, Dict[NodeType, Tensor]],
|
|
129
|
+
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
|
|
130
|
+
*,
|
|
131
|
+
target: Tensor,
|
|
132
|
+
index: Optional[Union[int, Tensor]] = None,
|
|
133
|
+
**kwargs,
|
|
134
|
+
) -> float:
|
|
94
135
|
r"""Trains the underlying explainer model.
|
|
95
136
|
Needs to be called before being able to make predictions.
|
|
96
137
|
|
|
97
138
|
Args:
|
|
98
139
|
epoch (int): The current epoch of the training phase.
|
|
99
140
|
model (torch.nn.Module): The model to explain.
|
|
100
|
-
x (torch.Tensor): The input node
|
|
101
|
-
homogeneous
|
|
102
|
-
edge_index (torch.Tensor): The input
|
|
103
|
-
|
|
141
|
+
x (torch.Tensor or Dict[str, torch.Tensor]): The input node
|
|
142
|
+
features. Can be either homogeneous or heterogeneous.
|
|
143
|
+
edge_index (torch.Tensor or Dict[Tuple[str, str, str]): The input
|
|
144
|
+
edge indices. Can be either homogeneous or heterogeneous.
|
|
104
145
|
target (torch.Tensor): The target of the model.
|
|
105
146
|
index (int or torch.Tensor, optional): The index of the model
|
|
106
147
|
output to explain. Needs to be a single index.
|
|
@@ -108,9 +149,9 @@ class PGExplainer(ExplainerAlgorithm):
|
|
|
108
149
|
**kwargs (optional): Additional keyword arguments passed to
|
|
109
150
|
:obj:`model`.
|
|
110
151
|
"""
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
152
|
+
self.is_hetero = isinstance(x, dict)
|
|
153
|
+
if self.is_hetero:
|
|
154
|
+
assert isinstance(edge_index, dict)
|
|
114
155
|
|
|
115
156
|
if self.model_config.task_level == ModelTaskLevel.node:
|
|
116
157
|
if index is None:
|
|
@@ -121,35 +162,68 @@ class PGExplainer(ExplainerAlgorithm):
|
|
|
121
162
|
raise ValueError(f"Only scalars are supported for the 'index' "
|
|
122
163
|
f"argument in '{self.__class__.__name__}'")
|
|
123
164
|
|
|
124
|
-
|
|
165
|
+
# Get embeddings based on whether the graph is homogeneous or
|
|
166
|
+
# heterogeneous
|
|
167
|
+
node_embeddings = self._get_embeddings(model, x, edge_index, **kwargs)
|
|
125
168
|
|
|
169
|
+
# Train the model
|
|
126
170
|
self.optimizer.zero_grad()
|
|
127
171
|
temperature = self._get_temperature(epoch)
|
|
128
172
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
if self.
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
173
|
+
# Process embeddings and generate edge masks
|
|
174
|
+
edge_mask = self._generate_edge_masks(node_embeddings, edge_index,
|
|
175
|
+
index, temperature)
|
|
176
|
+
|
|
177
|
+
# Apply masks to the model
|
|
178
|
+
if self.is_hetero:
|
|
179
|
+
set_hetero_masks(model, edge_mask, edge_index, apply_sigmoid=True)
|
|
180
|
+
|
|
181
|
+
# For node-level tasks, we can compute hard masks
|
|
182
|
+
if self.model_config.task_level == ModelTaskLevel.node:
|
|
183
|
+
# Process each edge type separately
|
|
184
|
+
for edge_type, mask in edge_mask.items():
|
|
185
|
+
# Get the edge indices for this edge type
|
|
186
|
+
edges = edge_index[edge_type]
|
|
187
|
+
src_type, _, dst_type = edge_type
|
|
188
|
+
|
|
189
|
+
# Get hard masks for this specific edge type
|
|
190
|
+
_, hard_mask = self._get_hard_masks(
|
|
191
|
+
model, index, edges,
|
|
192
|
+
num_nodes=max(x[src_type].size(0),
|
|
193
|
+
x[dst_type].size(0)))
|
|
194
|
+
|
|
195
|
+
edge_mask[edge_type] = mask[hard_mask]
|
|
196
|
+
else:
|
|
197
|
+
# Apply masks for homogeneous graphs
|
|
198
|
+
set_masks(model, edge_mask, edge_index, apply_sigmoid=True)
|
|
199
|
+
|
|
200
|
+
# For node-level tasks, we may need to apply hard masks
|
|
201
|
+
hard_edge_mask = None
|
|
202
|
+
if self.model_config.task_level == ModelTaskLevel.node:
|
|
203
|
+
_, hard_edge_mask = self._get_hard_masks(
|
|
204
|
+
model, index, edge_index, num_nodes=x.size(0))
|
|
205
|
+
edge_mask = edge_mask[hard_edge_mask]
|
|
206
|
+
|
|
207
|
+
# Forward pass with masks applied
|
|
139
208
|
y_hat, y = model(x, edge_index, **kwargs), target
|
|
140
209
|
|
|
141
210
|
if index is not None:
|
|
142
211
|
y_hat, y = y_hat[index], y[index]
|
|
143
212
|
|
|
213
|
+
# Calculate loss
|
|
144
214
|
loss = self._loss(y_hat, y, edge_mask)
|
|
215
|
+
|
|
216
|
+
# Backward pass and optimization
|
|
145
217
|
loss.backward()
|
|
146
218
|
self.optimizer.step()
|
|
147
219
|
|
|
220
|
+
# Clean up
|
|
148
221
|
clear_masks(model)
|
|
149
222
|
self._curr_epoch = epoch
|
|
150
223
|
|
|
151
224
|
return float(loss)
|
|
152
225
|
|
|
226
|
+
@overload
|
|
153
227
|
def forward(
|
|
154
228
|
self,
|
|
155
229
|
model: torch.nn.Module,
|
|
@@ -160,9 +234,32 @@ class PGExplainer(ExplainerAlgorithm):
|
|
|
160
234
|
index: Optional[Union[int, Tensor]] = None,
|
|
161
235
|
**kwargs,
|
|
162
236
|
) -> Explanation:
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
237
|
+
...
|
|
238
|
+
|
|
239
|
+
@overload
|
|
240
|
+
def forward(
|
|
241
|
+
self,
|
|
242
|
+
model: torch.nn.Module,
|
|
243
|
+
x: Dict[NodeType, Tensor],
|
|
244
|
+
edge_index: Dict[EdgeType, Tensor],
|
|
245
|
+
*,
|
|
246
|
+
target: Tensor,
|
|
247
|
+
index: Optional[Union[int, Tensor]] = None,
|
|
248
|
+
**kwargs,
|
|
249
|
+
) -> HeteroExplanation:
|
|
250
|
+
...
|
|
251
|
+
|
|
252
|
+
def forward(
|
|
253
|
+
self,
|
|
254
|
+
model: torch.nn.Module,
|
|
255
|
+
x: Union[Tensor, Dict[NodeType, Tensor]],
|
|
256
|
+
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
|
|
257
|
+
*,
|
|
258
|
+
target: Tensor,
|
|
259
|
+
index: Optional[Union[int, Tensor]] = None,
|
|
260
|
+
**kwargs,
|
|
261
|
+
) -> Union[Explanation, HeteroExplanation]:
|
|
262
|
+
self.is_hetero = isinstance(x, dict)
|
|
166
263
|
|
|
167
264
|
if self._curr_epoch < self.epochs - 1: # Safety check:
|
|
168
265
|
raise ValueError(f"'{self.__class__.__name__}' is not yet fully "
|
|
@@ -171,7 +268,6 @@ class PGExplainer(ExplainerAlgorithm):
|
|
|
171
268
|
f"the underlying explainer model by running "
|
|
172
269
|
f"`explainer.algorithm.train(...)`.")
|
|
173
270
|
|
|
174
|
-
hard_edge_mask = None
|
|
175
271
|
if self.model_config.task_level == ModelTaskLevel.node:
|
|
176
272
|
if index is None:
|
|
177
273
|
raise ValueError(f"The 'index' argument needs to be provided "
|
|
@@ -181,20 +277,55 @@ class PGExplainer(ExplainerAlgorithm):
|
|
|
181
277
|
raise ValueError(f"Only scalars are supported for the 'index' "
|
|
182
278
|
f"argument in '{self.__class__.__name__}'")
|
|
183
279
|
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
280
|
+
# Get embeddings
|
|
281
|
+
node_embeddings = self._get_embeddings(model, x, edge_index, **kwargs)
|
|
282
|
+
|
|
283
|
+
# Generate explanations
|
|
284
|
+
if self.is_hetero:
|
|
285
|
+
# Generate edge masks for each edge type
|
|
286
|
+
edge_masks = {}
|
|
287
|
+
|
|
288
|
+
# Generate masks for each edge type
|
|
289
|
+
for edge_type, edge_idx in edge_index.items():
|
|
290
|
+
src_node_type, _, dst_node_type = edge_type
|
|
291
|
+
|
|
292
|
+
assert src_node_type in node_embeddings
|
|
293
|
+
assert dst_node_type in node_embeddings
|
|
294
|
+
|
|
295
|
+
inputs = self._get_inputs_hetero(node_embeddings, edge_type,
|
|
296
|
+
edge_idx, index)
|
|
297
|
+
logits = self.mlp(inputs).view(-1)
|
|
298
|
+
|
|
299
|
+
# For node-level explanations, get hard masks for this
|
|
300
|
+
# specific edge type
|
|
301
|
+
hard_edge_mask = None
|
|
302
|
+
if self.model_config.task_level == ModelTaskLevel.node:
|
|
303
|
+
_, hard_edge_mask = self._get_hard_masks(
|
|
304
|
+
model, index, edge_idx,
|
|
305
|
+
num_nodes=max(x[src_node_type].size(0),
|
|
306
|
+
x[dst_node_type].size(0)))
|
|
188
307
|
|
|
189
|
-
|
|
308
|
+
# Apply hard mask if available and it has any True values
|
|
309
|
+
edge_masks[edge_type] = self._post_process_mask(
|
|
310
|
+
logits, hard_edge_mask, apply_sigmoid=True)
|
|
190
311
|
|
|
191
|
-
|
|
192
|
-
|
|
312
|
+
explanation = HeteroExplanation()
|
|
313
|
+
explanation.set_value_dict('edge_mask', edge_masks)
|
|
314
|
+
return explanation
|
|
315
|
+
else:
|
|
316
|
+
hard_edge_mask = None
|
|
317
|
+
if self.model_config.task_level == ModelTaskLevel.node:
|
|
318
|
+
# We need to compute hard masks to properly clean up edges
|
|
319
|
+
_, hard_edge_mask = self._get_hard_masks(
|
|
320
|
+
model, index, edge_index, num_nodes=x.size(0))
|
|
193
321
|
|
|
194
|
-
|
|
195
|
-
|
|
322
|
+
inputs = self._get_inputs(node_embeddings, edge_index, index)
|
|
323
|
+
logits = self.mlp(inputs).view(-1)
|
|
196
324
|
|
|
197
|
-
|
|
325
|
+
edge_mask = self._post_process_mask(logits, hard_edge_mask,
|
|
326
|
+
apply_sigmoid=True)
|
|
327
|
+
|
|
328
|
+
return Explanation(edge_mask=edge_mask)
|
|
198
329
|
|
|
199
330
|
def supports(self) -> bool:
|
|
200
331
|
explanation_type = self.explainer_config.explanation_type
|
|
@@ -222,6 +353,76 @@ class PGExplainer(ExplainerAlgorithm):
|
|
|
222
353
|
|
|
223
354
|
###########################################################################
|
|
224
355
|
|
|
356
|
+
def _get_embeddings(self, model: torch.nn.Module, x: Union[Tensor,
|
|
357
|
+
Dict[NodeType,
|
|
358
|
+
Tensor]],
|
|
359
|
+
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
|
|
360
|
+
**kwargs) -> Union[Tensor, Dict[NodeType, Tensor]]:
|
|
361
|
+
"""Get embeddings from the model based on input type."""
|
|
362
|
+
if self.is_hetero:
|
|
363
|
+
# For heterogeneous graphs, get embeddings for each node type
|
|
364
|
+
embeddings_dict = get_embeddings_hetero(
|
|
365
|
+
model,
|
|
366
|
+
self.SUPPORTED_HETERO_MODELS,
|
|
367
|
+
x,
|
|
368
|
+
edge_index,
|
|
369
|
+
**kwargs,
|
|
370
|
+
)
|
|
371
|
+
|
|
372
|
+
# Use the last layer's embeddings for each node type
|
|
373
|
+
last_embedding_dict = {
|
|
374
|
+
node_type: embs[-1] if embs and len(embs) > 0 else None
|
|
375
|
+
for node_type, embs in embeddings_dict.items()
|
|
376
|
+
}
|
|
377
|
+
|
|
378
|
+
# Skip if no embeddings were captured
|
|
379
|
+
if not any(emb is not None
|
|
380
|
+
for emb in last_embedding_dict.values()):
|
|
381
|
+
raise ValueError(
|
|
382
|
+
"No embeddings were captured from the model. "
|
|
383
|
+
"Please check if the model architecture is supported.")
|
|
384
|
+
|
|
385
|
+
return last_embedding_dict
|
|
386
|
+
else:
|
|
387
|
+
# For homogeneous graphs, get embeddings directly
|
|
388
|
+
return get_embeddings(model, x, edge_index, **kwargs)[-1]
|
|
389
|
+
|
|
390
|
+
def _generate_edge_masks(
|
|
391
|
+
self, emb: Union[Tensor, Dict[NodeType, Tensor]],
|
|
392
|
+
edge_index: Union[Tensor,
|
|
393
|
+
Dict[EdgeType,
|
|
394
|
+
Tensor]], index: Optional[Union[int,
|
|
395
|
+
Tensor]],
|
|
396
|
+
temperature: float) -> Union[Tensor, Dict[EdgeType, Tensor]]:
|
|
397
|
+
"""Generate edge masks based on embeddings."""
|
|
398
|
+
if self.is_hetero:
|
|
399
|
+
# For heterogeneous graphs, generate masks for each edge type
|
|
400
|
+
edge_masks = {}
|
|
401
|
+
|
|
402
|
+
for edge_type, edge_idx in edge_index.items():
|
|
403
|
+
src, _, dst = edge_type
|
|
404
|
+
|
|
405
|
+
assert src in emb and dst in emb
|
|
406
|
+
# Generate inputs for this edge type
|
|
407
|
+
inputs = self._get_inputs_hetero(emb, edge_type, edge_idx,
|
|
408
|
+
index)
|
|
409
|
+
logits = self.mlp(inputs).view(-1)
|
|
410
|
+
edge_masks[edge_type] = self._concrete_sample(
|
|
411
|
+
logits, temperature)
|
|
412
|
+
|
|
413
|
+
# Ensure we have at least one valid edge mask
|
|
414
|
+
if not edge_masks:
|
|
415
|
+
raise ValueError(
|
|
416
|
+
"Could not generate edge masks for any edge type. "
|
|
417
|
+
"Please ensure the model architecture is supported.")
|
|
418
|
+
|
|
419
|
+
return edge_masks
|
|
420
|
+
else:
|
|
421
|
+
# For homogeneous graphs, generate a single mask
|
|
422
|
+
inputs = self._get_inputs(emb, edge_index, index)
|
|
423
|
+
logits = self.mlp(inputs).view(-1)
|
|
424
|
+
return self._concrete_sample(logits, temperature)
|
|
425
|
+
|
|
225
426
|
def _get_inputs(self, embedding: Tensor, edge_index: Tensor,
|
|
226
427
|
index: Optional[int] = None) -> Tensor:
|
|
227
428
|
zs = [embedding[edge_index[0]], embedding[edge_index[1]]]
|
|
@@ -230,6 +431,27 @@ class PGExplainer(ExplainerAlgorithm):
|
|
|
230
431
|
zs.append(embedding[index].view(1, -1).repeat(zs[0].size(0), 1))
|
|
231
432
|
return torch.cat(zs, dim=-1)
|
|
232
433
|
|
|
434
|
+
def _get_inputs_hetero(self, embedding_dict: Dict[NodeType, Tensor],
|
|
435
|
+
edge_type: Tuple[str, str, str], edge_index: Tensor,
|
|
436
|
+
index: Optional[int] = None) -> Tensor:
|
|
437
|
+
src, _, dst = edge_type
|
|
438
|
+
|
|
439
|
+
# Get embeddings for source and destination nodes
|
|
440
|
+
src_emb = embedding_dict[src]
|
|
441
|
+
dst_emb = embedding_dict[dst]
|
|
442
|
+
|
|
443
|
+
# Source and destination node embeddings
|
|
444
|
+
zs = [src_emb[edge_index[0]], dst_emb[edge_index[1]]]
|
|
445
|
+
|
|
446
|
+
# For node-level explanations, add the target node embedding
|
|
447
|
+
if self.model_config.task_level == ModelTaskLevel.node:
|
|
448
|
+
assert index is not None
|
|
449
|
+
# Assuming index refers to a node of type 'src'
|
|
450
|
+
target_emb = src_emb[index].view(1, -1).repeat(zs[0].size(0), 1)
|
|
451
|
+
zs.append(target_emb)
|
|
452
|
+
|
|
453
|
+
return torch.cat(zs, dim=-1)
|
|
454
|
+
|
|
233
455
|
def _get_temperature(self, epoch: int) -> float:
|
|
234
456
|
temp = self.coeffs['temp']
|
|
235
457
|
return temp[0] * pow(temp[1] / temp[0], epoch / self.epochs)
|
|
@@ -240,19 +462,55 @@ class PGExplainer(ExplainerAlgorithm):
|
|
|
240
462
|
eps = (1 - 2 * bias) * torch.rand_like(logits) + bias
|
|
241
463
|
return (eps.log() - (1 - eps).log() + logits) / temperature
|
|
242
464
|
|
|
243
|
-
def _loss(self, y_hat: Tensor, y: Tensor,
|
|
465
|
+
def _loss(self, y_hat: Tensor, y: Tensor,
|
|
466
|
+
edge_mask: Union[Tensor, Dict[EdgeType, Tensor]]) -> Tensor:
|
|
467
|
+
# Calculate base loss based on model configuration
|
|
468
|
+
loss = self._calculate_base_loss(y_hat, y)
|
|
469
|
+
|
|
470
|
+
# Apply regularization based on graph type
|
|
471
|
+
if self.is_hetero:
|
|
472
|
+
loss = self._apply_hetero_regularization(loss, edge_mask)
|
|
473
|
+
else:
|
|
474
|
+
loss = self._apply_homo_regularization(loss, edge_mask)
|
|
475
|
+
|
|
476
|
+
return loss
|
|
477
|
+
|
|
478
|
+
def _calculate_base_loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
|
|
479
|
+
"""Calculate base loss based on model configuration."""
|
|
244
480
|
if self.model_config.mode == ModelMode.binary_classification:
|
|
245
|
-
|
|
481
|
+
return self._loss_binary_classification(y_hat, y)
|
|
246
482
|
elif self.model_config.mode == ModelMode.multiclass_classification:
|
|
247
|
-
|
|
483
|
+
return self._loss_multiclass_classification(y_hat, y)
|
|
248
484
|
elif self.model_config.mode == ModelMode.regression:
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
485
|
+
return self._loss_regression(y_hat, y)
|
|
486
|
+
else:
|
|
487
|
+
raise ValueError(
|
|
488
|
+
f"Unsupported model mode: {self.model_config.mode}")
|
|
489
|
+
|
|
490
|
+
def _apply_hetero_regularization(
|
|
491
|
+
self, loss: Tensor, edge_mask: Dict[EdgeType, Tensor]) -> Tensor:
|
|
492
|
+
"""Apply regularization for heterogeneous graph."""
|
|
493
|
+
for _, mask in edge_mask.items():
|
|
494
|
+
loss = self._add_mask_regularization(loss, mask)
|
|
495
|
+
|
|
496
|
+
return loss
|
|
497
|
+
|
|
498
|
+
def _apply_homo_regularization(self, loss: Tensor,
|
|
499
|
+
edge_mask: Tensor) -> Tensor:
|
|
500
|
+
"""Apply regularization for homogeneous graph."""
|
|
501
|
+
return self._add_mask_regularization(loss, edge_mask)
|
|
502
|
+
|
|
503
|
+
def _add_mask_regularization(self, loss: Tensor, mask: Tensor) -> Tensor:
|
|
504
|
+
"""Add size and entropy regularization for a mask."""
|
|
505
|
+
# Apply sigmoid for mask values
|
|
506
|
+
mask = mask.sigmoid()
|
|
507
|
+
|
|
508
|
+
# Size regularization
|
|
253
509
|
size_loss = mask.sum() * self.coeffs['edge_size']
|
|
254
|
-
|
|
255
|
-
|
|
510
|
+
|
|
511
|
+
# Entropy regularization
|
|
512
|
+
masked = 0.99 * mask + 0.005
|
|
513
|
+
mask_ent = -masked * masked.log() - (1 - masked) * (1 - masked).log()
|
|
256
514
|
mask_ent_loss = mask_ent.mean() * self.coeffs['edge_ent']
|
|
257
515
|
|
|
258
516
|
return loss + size_loss + mask_ent_loss
|
|
@@ -192,7 +192,7 @@ class Explainer:
|
|
|
192
192
|
if target is not None:
|
|
193
193
|
warnings.warn(
|
|
194
194
|
f"The 'target' should not be provided for the explanation "
|
|
195
|
-
f"type '{self.explanation_type.value}'")
|
|
195
|
+
f"type '{self.explanation_type.value}'", stacklevel=2)
|
|
196
196
|
prediction = self.get_prediction(x, edge_index, **kwargs)
|
|
197
197
|
target = self.get_target(prediction)
|
|
198
198
|
|
|
@@ -265,7 +265,7 @@ class Explainer:
|
|
|
265
265
|
return (prediction > 0).long().view(-1)
|
|
266
266
|
if self.model_config.return_type == ModelReturnType.probs:
|
|
267
267
|
return (prediction > 0.5).long().view(-1)
|
|
268
|
-
|
|
268
|
+
raise AssertionError()
|
|
269
269
|
|
|
270
270
|
if self.model_config.mode == ModelMode.multiclass_classification:
|
|
271
271
|
return prediction.argmax(dim=-1)
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import copy
|
|
2
|
-
from typing import Dict, List, Optional, Union
|
|
2
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
5
|
from torch import Tensor
|
|
@@ -8,7 +8,10 @@ from torch_geometric.data.data import Data, warn_or_raise
|
|
|
8
8
|
from torch_geometric.data.hetero_data import HeteroData
|
|
9
9
|
from torch_geometric.explain.config import ThresholdConfig, ThresholdType
|
|
10
10
|
from torch_geometric.typing import EdgeType, NodeType
|
|
11
|
-
from torch_geometric.visualization import
|
|
11
|
+
from torch_geometric.visualization import (
|
|
12
|
+
visualize_graph,
|
|
13
|
+
visualize_hetero_graph,
|
|
14
|
+
)
|
|
12
15
|
|
|
13
16
|
|
|
14
17
|
class ExplanationMixin:
|
|
@@ -100,7 +103,7 @@ class ExplanationMixin:
|
|
|
100
103
|
out[index] = 1.0
|
|
101
104
|
return out.view(mask.size())
|
|
102
105
|
|
|
103
|
-
|
|
106
|
+
raise AssertionError()
|
|
104
107
|
|
|
105
108
|
def threshold(
|
|
106
109
|
self,
|
|
@@ -340,10 +343,10 @@ class HeteroExplanation(HeteroData, ExplanationMixin):
|
|
|
340
343
|
"""
|
|
341
344
|
node_mask_dict = self.node_mask_dict
|
|
342
345
|
for node_mask in node_mask_dict.values():
|
|
343
|
-
if node_mask.dim() != 2
|
|
346
|
+
if node_mask.dim() != 2:
|
|
344
347
|
raise ValueError(f"Cannot compute feature importance for "
|
|
345
348
|
f"object-level 'node_mask' "
|
|
346
|
-
f"(got shape {
|
|
349
|
+
f"(got shape {node_mask.size()})")
|
|
347
350
|
|
|
348
351
|
if feat_labels is None:
|
|
349
352
|
feat_labels = {}
|
|
@@ -362,6 +365,87 @@ class HeteroExplanation(HeteroData, ExplanationMixin):
|
|
|
362
365
|
|
|
363
366
|
return _visualize_score(score, all_feat_labels, path, top_k)
|
|
364
367
|
|
|
368
|
+
def visualize_graph(
|
|
369
|
+
self,
|
|
370
|
+
path: Optional[str] = None,
|
|
371
|
+
node_labels: Optional[Dict[NodeType, List[str]]] = None,
|
|
372
|
+
node_size_range: Tuple[float, float] = (50, 500),
|
|
373
|
+
node_opacity_range: Tuple[float, float] = (0.2, 1.0),
|
|
374
|
+
edge_width_range: Tuple[float, float] = (0.1, 2.0),
|
|
375
|
+
edge_opacity_range: Tuple[float, float] = (0.2, 1.0),
|
|
376
|
+
) -> None:
|
|
377
|
+
r"""Visualizes the explanation subgraph using networkx, with edge
|
|
378
|
+
opacity corresponding to edge importance and node colors
|
|
379
|
+
corresponding to node types.
|
|
380
|
+
|
|
381
|
+
Args:
|
|
382
|
+
path (str, optional): The path to where the plot is saved.
|
|
383
|
+
If set to :obj:`None`, will visualize the plot on-the-fly.
|
|
384
|
+
(default: :obj:`None`)
|
|
385
|
+
node_labels (Dict[NodeType, List[str]], optional): The display
|
|
386
|
+
names of nodes for each node type that will be shown in the
|
|
387
|
+
visualization. (default: :obj:`None`)
|
|
388
|
+
node_size_range (Tuple[float, float], optional): The minimum and
|
|
389
|
+
maximum node size in the visualization.
|
|
390
|
+
(default: :obj:`(50, 500)`)
|
|
391
|
+
node_opacity_range (Tuple[float, float], optional): The minimum and
|
|
392
|
+
maximum node opacity in the visualization.
|
|
393
|
+
(default: :obj:`(0.2, 1.0)`)
|
|
394
|
+
edge_width_range (Tuple[float, float], optional): The minimum and
|
|
395
|
+
maximum edge width in the visualization.
|
|
396
|
+
(default: :obj:`(0.1, 2.0)`)
|
|
397
|
+
edge_opacity_range (Tuple[float, float], optional): The minimum and
|
|
398
|
+
maximum edge opacity in the visualization.
|
|
399
|
+
(default: :obj:`(0.2, 1.0)`)
|
|
400
|
+
"""
|
|
401
|
+
# Validate node labels if provided
|
|
402
|
+
if node_labels is not None:
|
|
403
|
+
for node_type, labels in node_labels.items():
|
|
404
|
+
if node_type not in self.node_types:
|
|
405
|
+
raise ValueError(
|
|
406
|
+
f"Node type '{node_type}' in node_labels "
|
|
407
|
+
f"does not exist in the explanation graph")
|
|
408
|
+
if len(labels) != self[node_type].num_nodes:
|
|
409
|
+
raise ValueError(f"Number of labels for node type "
|
|
410
|
+
f"'{node_type}' (got {len(labels)}) does "
|
|
411
|
+
f"not match the number of nodes "
|
|
412
|
+
f"(got {self[node_type].num_nodes})")
|
|
413
|
+
# Get the explanation subgraph
|
|
414
|
+
subgraph = self.get_explanation_subgraph()
|
|
415
|
+
|
|
416
|
+
# Prepare edge indices and weights for each edge type
|
|
417
|
+
edge_index_dict = {}
|
|
418
|
+
edge_weight_dict = {}
|
|
419
|
+
for edge_type in subgraph.edge_types:
|
|
420
|
+
if edge_type[0] == 'x' or edge_type[-1] == 'x': # Skip edges
|
|
421
|
+
continue
|
|
422
|
+
edge_index_dict[edge_type] = subgraph[edge_type].edge_index
|
|
423
|
+
edge_weight_dict[edge_type] = subgraph[edge_type].get(
|
|
424
|
+
'edge_mask',
|
|
425
|
+
torch.ones(subgraph[edge_type].edge_index.size(1)))
|
|
426
|
+
|
|
427
|
+
# Prepare node weights for each node type
|
|
428
|
+
node_weight_dict = {}
|
|
429
|
+
for node_type in subgraph.node_types:
|
|
430
|
+
if node_type == 'x': # Skip the global store
|
|
431
|
+
continue
|
|
432
|
+
node_weight_dict[node_type] = subgraph[node_type] \
|
|
433
|
+
.get('node_mask',
|
|
434
|
+
torch.ones(subgraph[node_type].num_nodes)).squeeze(-1)
|
|
435
|
+
|
|
436
|
+
# Call the visualization function
|
|
437
|
+
visualize_hetero_graph(
|
|
438
|
+
edge_index_dict=edge_index_dict,
|
|
439
|
+
edge_weight_dict=edge_weight_dict,
|
|
440
|
+
path=path,
|
|
441
|
+
node_labels_dict=node_labels,
|
|
442
|
+
node_weight_dict=node_weight_dict,
|
|
443
|
+
node_size_range=node_size_range,
|
|
444
|
+
node_opacity_range=node_opacity_range,
|
|
445
|
+
edge_width_range=edge_width_range,
|
|
446
|
+
edge_opacity_range=edge_opacity_range,
|
|
447
|
+
)
|
|
448
|
+
|
|
365
449
|
|
|
366
450
|
def _visualize_score(
|
|
367
451
|
score: torch.Tensor,
|
|
@@ -13,7 +13,7 @@ def unfaithfulness(
|
|
|
13
13
|
top_k: Optional[int] = None,
|
|
14
14
|
) -> float:
|
|
15
15
|
r"""Evaluates how faithful an :class:`~torch_geometric.explain.Explanation`
|
|
16
|
-
is to an
|
|
16
|
+
is to an underlying GNN predictor, as described in the
|
|
17
17
|
`"Evaluating Explainability for Graph Neural Networks"
|
|
18
18
|
<https://arxiv.org/abs/2208.09339>`_ paper.
|
|
19
19
|
|
|
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Union
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
8
|
from torch_geometric.graphgym.config import cfg
|
|
9
|
+
from torch_geometric.io import fs
|
|
9
10
|
|
|
10
11
|
MODEL_STATE = 'model_state'
|
|
11
12
|
OPTIMIZER_STATE = 'optimizer_state'
|
|
@@ -25,7 +26,7 @@ def load_ckpt(
|
|
|
25
26
|
if not osp.exists(path):
|
|
26
27
|
return 0
|
|
27
28
|
|
|
28
|
-
ckpt =
|
|
29
|
+
ckpt = fs.torch_load(path)
|
|
29
30
|
model.load_state_dict(ckpt[MODEL_STATE])
|
|
30
31
|
if optimizer is not None and OPTIMIZER_STATE in ckpt:
|
|
31
32
|
optimizer.load_state_dict(ckpt[OPTIMIZER_STATE])
|
|
@@ -16,8 +16,9 @@ try: # Define global config object
|
|
|
16
16
|
cfg = CN()
|
|
17
17
|
except ImportError:
|
|
18
18
|
cfg = None
|
|
19
|
-
warnings.warn(
|
|
20
|
-
|
|
19
|
+
warnings.warn(
|
|
20
|
+
"Could not define global config object. Please install "
|
|
21
|
+
"'yacs' via 'pip install yacs' in order to use GraphGym", stacklevel=2)
|
|
21
22
|
|
|
22
23
|
|
|
23
24
|
def set_cfg(cfg):
|