pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +14 -2
- torch_geometric/_compile.py +9 -3
- torch_geometric/_onnx.py +214 -0
- torch_geometric/config_mixin.py +5 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +109 -5
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +14 -11
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +18 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +15 -17
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +310 -209
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/partition.py +2 -2
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +18 -14
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +87 -3
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +14 -5
- torch_geometric/inspector.py +4 -0
- torch_geometric/io/fs.py +5 -4
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +4 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +14 -0
- torch_geometric/metrics/link_pred.py +745 -92
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +5 -4
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +0 -20
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +14 -12
- torch_geometric/nn/model_hub.py +2 -15
- torch_geometric/nn/models/__init__.py +11 -2
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/cluster_pool.py +3 -4
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +336 -7
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +296 -23
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -8
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/typing.py +70 -17
- torch_geometric/utils/__init__.py +4 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +27 -12
- torch_geometric/utils/_scatter.py +132 -195
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_to_dense_batch.py +2 -2
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +13 -9
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +7 -14
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/sentence_transformer.py +0 -101
torch_geometric/edge_index.py
CHANGED

@@ -17,6 +17,7 @@ from typing import (
     overload,
 )

+import numpy as np
 import torch
 import torch.utils._pytree as pytree
 from torch import Tensor
@@ -183,7 +184,7 @@ class EdgeIndex(Tensor):

         edge_index = EdgeIndex(
             [[0, 1, 1, 2],
-             [1, 0, 2, 1]]
+             [1, 0, 2, 1]],
             sparse_size=(3, 3),
             sort_order='row',
             is_undirected=True,
@@ -210,7 +211,7 @@ class EdgeIndex(Tensor):
         assert not edge_index.is_undirected

         # Sparse-Dense Matrix Multiplication:
-        out = edge_index.flip(0) @
+        out = edge_index.flip(0) @ torch.randn(3, 16)
         assert out.size() == (3, 16)
     """
     # See "https://pytorch.org/docs/stable/notes/extending.html"
@@ -297,8 +298,7 @@ class EdgeIndex(Tensor):
             indptr = None
             data = torch.stack([row, col], dim=0)

-            if (torch_geometric.typing.WITH_PT112
-                    and data.layout == torch.sparse_csc):
+            if data.layout == torch.sparse_csc:
                 row = data.row_indices()
                 indptr = data.ccol_indices()

@@ -325,7 +325,7 @@ class EdgeIndex(Tensor):
        elif sparse_size[0] is None and sparse_size[1] is not None:
            sparse_size = (sparse_size[1], sparse_size[1])

-        out = Tensor._make_wrapper_subclass(
+        out = Tensor._make_wrapper_subclass(
            cls,
            size=data.size(),
            strides=data.stride(),
@@ -803,7 +803,7 @@ class EdgeIndex(Tensor):

         size = self.get_sparse_size()
         if value is not None and value.dim() > 1:
-            size = size + value.size()[1:]
+            size = size + value.size()[1:]

         out = torch.full(size, fill_value, dtype=dtype, device=self.device)
         out[self._data[0], self._data[1]] = value if value is not None else 1
@@ -881,10 +881,6 @@ class EdgeIndex(Tensor):
                 If not specified, non-zero elements will be assigned a value of
                 :obj:`1.0`. (default: :obj:`None`)
         """
-        if not torch_geometric.typing.WITH_PT112:
-            raise NotImplementedError(
-                "'to_sparse_csc' not supported for PyTorch < 1.12")
-
         (colptr, row), perm = self.get_csc()
         if value is not None and perm is not None:
             value = value[perm]
@@ -921,7 +917,7 @@ class EdgeIndex(Tensor):
            return self.to_sparse_coo(value)
        if layout == torch.sparse_csr:
            return self.to_sparse_csr(value)
-       if torch_geometric.typing.WITH_PT112 and layout == torch.sparse_csc:
+       if layout == torch.sparse_csc:
            return self.to_sparse_csc(value)

        raise ValueError(f"Unexpected tensor layout (got '{layout}')")
@@ -1190,10 +1186,10 @@ class EdgeIndex(Tensor):
         return edge_index

     # Prevent auto-wrapping outputs back into the proper subclass type:
-    __torch_function__ = torch._C._disabled_torch_function_impl
+    __torch_function__ = torch._C._disabled_torch_function_impl  # type: ignore

     @classmethod
-    def __torch_dispatch__(
+    def __torch_dispatch__(  # type: ignore
         cls: Type,
         func: Callable[..., Any],
         types: Iterable[Type[Any]],
@@ -1246,6 +1242,14 @@ class EdgeIndex(Tensor):
         return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes,
                                                indent, force_newline=False)

+    def tolist(self) -> List[Any]:
+        """"""  # noqa: D419
+        return self._data.tolist()
+
+    def numpy(self, *, force: bool = False) -> np.ndarray:
+        """"""  # noqa: D419
+        return self._data.numpy(force=force)
+
     # Helpers #################################################################

     def _shallow_copy(self) -> 'EdgeIndex':
@@ -1478,7 +1482,7 @@ def _slice(
     step: int = 1,
 ) -> Union[EdgeIndex, Tensor]:

-    if ((start is None or start <= -input.size(dim))
+    if ((start is None or start == 0 or start <= -input.size(dim))
             and (end is None or end > input.size(dim)) and step == 1):
         return input._shallow_copy()  # No-op.

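The `EdgeIndex` changes above add `tolist()` and `numpy()` pass-throughs to the wrapped data tensor and drop the PyTorch < 1.12 guards around CSC conversion. A minimal sketch of the resulting surface, assuming the nightly build in this diff; the constructor arguments mirror the class docstring shown in the hunks:

import torch
from torch_geometric import EdgeIndex

edge_index = EdgeIndex(
    [[0, 1, 1, 2],
     [1, 0, 2, 1]],
    sparse_size=(3, 3),
    sort_order='row',
    is_undirected=True,
)

# Sparse-dense matrix multiplication, as in the corrected docstring example:
out = edge_index.flip(0) @ torch.randn(3, 16)
assert out.size() == (3, 16)

# New pass-through conversions added in this release:
assert edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]
array = edge_index.numpy()  # plain numpy.ndarray of shape (2, 4)

# CSC conversion, no longer gated on a PyTorch version check:
csc = edge_index.to_sparse_csc()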
torch_geometric/explain/algorithm/attention_explainer.py
CHANGED

@@ -1,13 +1,14 @@
 import logging
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union, overload

 import torch
 from torch import Tensor

-from torch_geometric.explain import Explanation
+from torch_geometric.explain import Explanation, HeteroExplanation
 from torch_geometric.explain.algorithm import ExplainerAlgorithm
 from torch_geometric.explain.config import ExplanationType, ModelTaskLevel
 from torch_geometric.nn.conv.message_passing import MessagePassing
+from torch_geometric.typing import EdgeType, NodeType


 class AttentionExplainer(ExplainerAlgorithm):
@@ -26,7 +27,9 @@ class AttentionExplainer(ExplainerAlgorithm):
     def __init__(self, reduce: str = 'max'):
         super().__init__()
         self.reduce = reduce
+        self.is_hetero = False

+    @overload
     def forward(
         self,
         model: torch.nn.Module,
@@ -37,65 +40,252 @@ class AttentionExplainer(ExplainerAlgorithm):
         index: Optional[Union[int, Tensor]] = None,
         **kwargs,
     ) -> Explanation:
-        if isinstance(x, dict) or isinstance(edge_index, dict):
-            raise ValueError(f"Heterogeneous graphs not yet supported in "
-                             f"'{self.__class__.__name__}'")
+        ...

-        hard_edge_mask = None
-        if self.model_config.task_level == ModelTaskLevel.node:
-            # We need to compute hard masks to properly clean up edges and
-            # nodes attributions not involved during message passing:
-            _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,
-                                                     num_nodes=x.size(0))
+    @overload
+    def forward(
+        self,
+        model: torch.nn.Module,
+        x: Dict[NodeType, Tensor],
+        edge_index: Dict[EdgeType, Tensor],
+        *,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ) -> HeteroExplanation:
+        ...
+
+    def forward(
+        self,
+        model: torch.nn.Module,
+        x: Union[Tensor, Dict[NodeType, Tensor]],
+        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+        *,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ) -> Union[Explanation, HeteroExplanation]:
+        """Generate explanations based on attention coefficients."""
+        self.is_hetero = isinstance(x, dict)
+
+        # Collect attention coefficients
+        alphas_dict = self._collect_attention_coefficients(
+            model, x, edge_index, **kwargs)
+
+        # Process attention coefficients
+        if self.is_hetero:
+            return self._create_hetero_explanation(model, alphas_dict,
+                                                   edge_index, index, x)
+        else:
+            return self._create_homo_explanation(model, alphas_dict,
+                                                 edge_index, index, x)
+
+    @overload
+    def _collect_attention_coefficients(
+        self,
+        model: torch.nn.Module,
+        x: Tensor,
+        edge_index: Tensor,
+        **kwargs,
+    ) -> List[Tensor]:
+        ...
+
+    @overload
+    def _collect_attention_coefficients(
+        self,
+        model: torch.nn.Module,
+        x: Dict[NodeType, Tensor],
+        edge_index: Dict[EdgeType, Tensor],
+        **kwargs,
+    ) -> Dict[EdgeType, List[Tensor]]:
+        ...
+
+    def _collect_attention_coefficients(
+        self,
+        model: torch.nn.Module,
+        x: Union[Tensor, Dict[NodeType, Tensor]],
+        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+        **kwargs,
+    ) -> Union[List[Tensor], Dict[EdgeType, List[Tensor]]]:
+        """Collect attention coefficients from model layers."""
+        if self.is_hetero:
+            # For heterogeneous graphs, store alphas by edge type
+            alphas_dict: Dict[EdgeType, List[Tensor]] = {}
+
+            # Get list of edge types
+            edge_types = list(edge_index.keys())
+
+            # Hook function to capture attention coefficients by edge type
+            def hook(module, msg_kwargs, out):
+                # Find edge type from the module's full name
+                module_name = getattr(module, '_name', None)
+                if module_name is None:
+                    return

-        alphas: List[Tensor] = []
+                edge_type = None
+                for edge_tuple in edge_types:
+                    src_type, edge_name, dst_type = edge_tuple
+                    # Check if all components appear in the module name in
+                    # order
+                    try:
+                        src_idx = module_name.index(src_type)
+                        edge_idx = module_name.index(edge_name, src_idx)
+                        dst_idx = module_name.index(dst_type, edge_idx)
+                        if src_idx < edge_idx < dst_idx:
+                            edge_type = edge_tuple
+                            break
+                    except ValueError:  # Component not found
+                        continue
+
+                if edge_type is None:
+                    return
+
+                if edge_type not in alphas_dict:
+                    alphas_dict[edge_type] = []
+
+                # Extract alpha from message kwargs or module
+                if 'alpha' in msg_kwargs[0]:
+                    alphas_dict[edge_type].append(
+                        msg_kwargs[0]['alpha'].detach())
+                elif getattr(module, '_alpha', None) is not None:
+                    alphas_dict[edge_type].append(module._alpha.detach())
+        else:
+            # For homogeneous graphs, store all alphas in a list
+            alphas: List[Tensor] = []

-        def hook(module, msg_kwargs, out):
-            if 'alpha' in msg_kwargs[0]:
-                alphas.append(msg_kwargs[0]['alpha'].detach())
-            elif getattr(module, '_alpha', None) is not None:
-                alphas.append(module._alpha.detach())
+            def hook(module, msg_kwargs, out):
+                if 'alpha' in msg_kwargs[0]:
+                    alphas.append(msg_kwargs[0]['alpha'].detach())
+                elif getattr(module, '_alpha', None) is not None:
+                    alphas.append(module._alpha.detach())

+        # Register hooks for all message passing modules
         hook_handles = []
-        for module in model.modules():  # Register message forward hooks:
-            if (isinstance(module, MessagePassing)
-                    and module.explain is not False):
+        for name, module in model.named_modules():
+            if isinstance(module,
+                          MessagePassing) and module.explain is not False:
+                # Store name for hetero graph lookup in the hook
+                if self.is_hetero:
+                    module._name = name
+
                 hook_handles.append(module.register_message_forward_hook(hook))

+        # Forward pass to collect attention coefficients.
         model(x, edge_index, **kwargs)

-        for handle in hook_handles:  # Remove hooks:
+        # Remove hooks
+        for handle in hook_handles:
             handle.remove()

-        if len(alphas) == 0:
-            raise ValueError("Could not collect any attention coefficients. "
-                             "Please ensure that your model is using "
-                             "attention-based GNN layers.")
+        # Check if we collected any attention coefficients.
+        if self.is_hetero:
+            if not alphas_dict:
+                raise ValueError(
+                    "Could not collect any attention coefficients. "
+                    "Please ensure that your model is using "
+                    "attention-based GNN layers.")
+            return alphas_dict
+        else:
+            if not alphas:
+                raise ValueError(
+                    "Could not collect any attention coefficients. "
+                    "Please ensure that your model is using "
+                    "attention-based GNN layers.")
+            return alphas

+    def _process_attention_coefficients(
+        self,
+        alphas: List[Tensor],
+        edge_index_size: int,
+    ) -> Tensor:
+        """Process collected attention coefficients into a single mask."""
         for i, alpha in enumerate(alphas):
-            alpha = alpha[:edge_index.size(1)]
+            # Ensure alpha doesn't exceed edge_index size
+            alpha = alpha[:edge_index_size]
+
+            # Reduce multi-head attention
             if alpha.dim() == 2:
                 alpha = getattr(torch, self.reduce)(alpha, dim=-1)
-                if isinstance(alpha, tuple):  #
+                if isinstance(alpha, tuple):  # Handle torch.max output
                     alpha = alpha[0]
             elif alpha.dim() > 2:
-                raise ValueError(f"
+                raise ValueError(f"Cannot reduce attention coefficients of "
                                  f"shape {list(alpha.size())}")
             alphas[i] = alpha

+        # Combine attention coefficients across layers
         if len(alphas) > 1:
             alpha = torch.stack(alphas, dim=-1)
             alpha = getattr(torch, self.reduce)(alpha, dim=-1)
-            if isinstance(alpha, tuple):  #
+            if isinstance(alpha, tuple):  # Handle torch.max output
                 alpha = alpha[0]
         else:
             alpha = alphas[0]

+        return alpha
+
+    def _create_homo_explanation(
+        self,
+        model: torch.nn.Module,
+        alphas: List[Tensor],
+        edge_index: Tensor,
+        index: Optional[Union[int, Tensor]],
+        x: Tensor,
+    ) -> Explanation:
+        """Create explanation for homogeneous graph."""
+        # Get hard edge mask for node-level tasks
+        hard_edge_mask = None
+        if self.model_config.task_level == ModelTaskLevel.node:
+            _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,
+                                                     num_nodes=x.size(0))
+
+        # Process attention coefficients
+        alpha = self._process_attention_coefficients(alphas,
+                                                     edge_index.size(1))
+
+        # Post-process mask with hard edge mask if needed
         alpha = self._post_process_mask(alpha, hard_edge_mask,
                                         apply_sigmoid=False)

         return Explanation(edge_mask=alpha)

+    def _create_hetero_explanation(
+        self,
+        model: torch.nn.Module,
+        alphas_dict: Dict[EdgeType, List[Tensor]],
+        edge_index: Dict[EdgeType, Tensor],
+        index: Optional[Union[int, Tensor]],
+        x: Dict[NodeType, Tensor],
+    ) -> HeteroExplanation:
+        """Create explanation for heterogeneous graph."""
+        edge_masks_dict = {}
+
+        # Process each edge type separately
+        for edge_type, alphas in alphas_dict.items():
+            if not alphas:
+                continue
+
+            # Get hard edge mask for node-level tasks
+            hard_edge_mask = None
+            if self.model_config.task_level == ModelTaskLevel.node:
+                src_type, _, dst_type = edge_type
+                _, hard_edge_mask = self._get_hard_masks(
+                    model, index, edge_index[edge_type],
+                    num_nodes=max(x[src_type].size(0), x[dst_type].size(0)))
+
+            # Process attention coefficients for this edge type
+            alpha = self._process_attention_coefficients(
+                alphas, edge_index[edge_type].size(1))
+
+            # Apply hard mask if available
+            edge_masks_dict[edge_type] = self._post_process_mask(
+                alpha, hard_edge_mask, apply_sigmoid=False)
+
+        # Create heterogeneous explanation
+        explanation = HeteroExplanation()
+        explanation.set_value_dict('edge_mask', edge_masks_dict)
+        return explanation
+
     def supports(self) -> bool:
         explanation_type = self.explainer_config.explanation_type
         if explanation_type != ExplanationType.model:
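`AttentionExplainer` now dispatches on its input type: tensor inputs keep returning an `Explanation`, while per-type dictionaries produce a `HeteroExplanation` with one `edge_mask` per edge type, matched to submodules by name as in the hook logic above. A runnable homogeneous sketch; the model and dimensions are illustrative, and any layer exposing attention weights the way `GATConv` does works the same way:

import torch
from torch_geometric.explain import Explainer
from torch_geometric.explain.algorithm import AttentionExplainer
from torch_geometric.nn import GATConv


class GAT(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GATConv(8, 16, heads=2)   # exposes `alpha` to the hook
        self.conv2 = GATConv(32, 7, heads=1)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)


x = torch.randn(10, 8)
edge_index = torch.randint(0, 10, (2, 40))

explainer = Explainer(
    model=GAT(),
    algorithm=AttentionExplainer(reduce='max'),
    explanation_type='model',
    edge_mask_type='object',
    model_config=dict(mode='multiclass_classification', task_level='node',
                      return_type='raw'),
)

explanation = explainer(x, edge_index, index=0)  # -> Explanation
print(explanation.edge_mask.size())  # one attention score per edge

# With a `to_hetero()` model and a `HeteroData` object, the same call takes
# dicts and returns a `HeteroExplanation` instead:
# explanation = explainer(data.x_dict, data.edge_index_dict, index=0)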
torch_geometric/explain/algorithm/base.py
CHANGED

@@ -166,7 +166,7 @@ class ExplainerAlgorithm(torch.nn.Module):
         elif self.model_config.return_type == ModelReturnType.probs:
             loss_fn = F.binary_cross_entropy
         else:
-            assert False
+            raise AssertionError()

         return loss_fn(y_hat.view_as(y), y.float())

@@ -183,7 +183,7 @@ class ExplainerAlgorithm(torch.nn.Module):
         elif self.model_config.return_type == ModelReturnType.log_probs:
             loss_fn = F.nll_loss
         else:
-            assert False
+            raise AssertionError()

         return loss_fn(y_hat, y)

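The two `base.py` hunks make the unreachable branch raise explicitly, matching flake8-bugbear's B011 guidance: a bare `assert False` is stripped when Python runs with `-O`, while `raise AssertionError()` always fires. A tiny standalone illustration; the function name is made up for the sketch:

def pick_loss(return_type: str) -> str:
    if return_type == 'probs':
        return 'binary_cross_entropy'
    # With `assert False` here, running under `python -O` would strip the
    # check and silently return None; the explicit raise cannot be stripped.
    raise AssertionError(f"Unexpected return type (got '{return_type}')")


print(pick_loss('probs'))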
torch_geometric/explain/algorithm/captum.py
CHANGED

@@ -190,7 +190,7 @@ def to_captum_input(

     Args:
         x (torch.Tensor or Dict[NodeType, torch.Tensor]): The node features.
-            For heterogeneous graphs this is a dictionary holding node
+            For heterogeneous graphs this is a dictionary holding node features
             for each node type.
         edge_index(torch.Tensor or Dict[EdgeType, torch.Tensor]): The edge
             indices. For heterogeneous graphs this is a dictionary holding the
torch_geometric/explain/algorithm/captum_explainer.py
CHANGED

@@ -73,7 +73,8 @@ class CaptumExplainer(ExplainerAlgorithm):
                 f"{self.attribution_method_class.__name__}")

         if kwargs.get('internal_batch_size', 1) != 1:
-            warnings.warn("Overriding 'internal_batch_size' to 1")
+            warnings.warn("Overriding 'internal_batch_size' to 1",
+                          stacklevel=2)

         if 'internal_batch_size' in self._get_attribute_parameters():
             kwargs['internal_batch_size'] = 1