pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +14 -2
- torch_geometric/_compile.py +9 -3
- torch_geometric/_onnx.py +214 -0
- torch_geometric/config_mixin.py +5 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +109 -5
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +14 -11
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +18 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +15 -17
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +310 -209
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/partition.py +2 -2
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +18 -14
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +87 -3
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +14 -5
- torch_geometric/inspector.py +4 -0
- torch_geometric/io/fs.py +5 -4
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +4 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +14 -0
- torch_geometric/metrics/link_pred.py +745 -92
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +5 -4
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +0 -20
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +14 -12
- torch_geometric/nn/model_hub.py +2 -15
- torch_geometric/nn/models/__init__.py +11 -2
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/cluster_pool.py +3 -4
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +336 -7
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +296 -23
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -8
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/typing.py +70 -17
- torch_geometric/utils/__init__.py +4 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +27 -12
- torch_geometric/utils/_scatter.py +132 -195
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_to_dense_batch.py +2 -2
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +13 -9
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +7 -14
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/sentence_transformer.py +0 -101
torch_geometric/explain/algorithm/gnn_explainer.py

@@ -1,14 +1,24 @@
 from math import sqrt
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union, overload
 
 import torch
 from torch import Tensor
 from torch.nn.parameter import Parameter
 
-from torch_geometric.explain import ExplainerConfig, Explanation, ModelConfig
+from torch_geometric.explain import (
+    ExplainerConfig,
+    Explanation,
+    HeteroExplanation,
+    ModelConfig,
+)
 from torch_geometric.explain.algorithm import ExplainerAlgorithm
-from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
+from torch_geometric.explain.algorithm.utils import (
+    clear_masks,
+    set_hetero_masks,
+    set_masks,
+)
 from torch_geometric.explain.config import MaskType, ModelMode, ModelTaskLevel
+from torch_geometric.typing import EdgeType, NodeType
 
 
 class GNNExplainer(ExplainerAlgorithm):
@@ -51,7 +61,7 @@ class GNNExplainer(ExplainerAlgorithm):
         :attr:`~torch_geometric.explain.algorithm.GNNExplainer.coeffs`.
     """
 
-    coeffs = {
+    default_coeffs = {
         'edge_size': 0.005,
         'edge_reduction': 'sum',
         'node_feat_size': 1.0,
@@ -65,11 +75,14 @@ class GNNExplainer(ExplainerAlgorithm):
         super().__init__()
         self.epochs = epochs
         self.lr = lr
+        self.coeffs = dict(self.default_coeffs)
         self.coeffs.update(kwargs)
 
         self.node_mask = self.hard_node_mask = None
         self.edge_mask = self.hard_edge_mask = None
+        self.is_hetero = False
 
+    @overload
     def forward(
         self,
         model: torch.nn.Module,
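
The hunk above also changes how the regularization coefficients are stored: the class-level dict is renamed to `default_coeffs`, and each instance now copies it into `self.coeffs` before applying keyword overrides, so per-instance overrides no longer touch the shared defaults. A minimal sketch of what that means for callers; the overridden keys `edge_size` and `node_feat_ent` are taken from the coefficient names visible elsewhere in this diff:

```python
from torch_geometric.explain.algorithm import GNNExplainer

# Unknown keyword arguments are treated as coefficient overrides and merged
# into the per-instance copy of `default_coeffs`.
algorithm = GNNExplainer(epochs=200, lr=0.01, edge_size=0.01, node_feat_ent=0.5)

assert algorithm.coeffs['edge_size'] == 0.01               # instance override
assert GNNExplainer.default_coeffs['edge_size'] == 0.005   # class default untouched
```
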
@@ -80,30 +93,87 @@ class GNNExplainer(ExplainerAlgorithm):
         index: Optional[Union[int, Tensor]] = None,
         **kwargs,
     ) -> Explanation:
-        if isinstance(x, dict) or isinstance(edge_index, dict):
-            raise ValueError(f"Heterogeneous graphs not yet supported in "
-                             f"'{self.__class__.__name__}'")
+        ...
 
-        self._train(model, x, edge_index, target=target, index=index, **kwargs)
-
-        node_mask = self._post_process_mask(
-            self.node_mask,
-            self.hard_node_mask,
-            apply_sigmoid=True,
-        )
-        edge_mask = self._post_process_mask(
-            self.edge_mask,
-            self.hard_edge_mask,
-            apply_sigmoid=True,
-        )
+    @overload
+    def forward(
+        self,
+        model: torch.nn.Module,
+        x: Dict[NodeType, Tensor],
+        edge_index: Dict[EdgeType, Tensor],
+        *,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ) -> HeteroExplanation:
+        ...
 
+    def forward(
+        self,
+        model: torch.nn.Module,
+        x: Union[Tensor, Dict[NodeType, Tensor]],
+        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+        *,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ) -> Union[Explanation, HeteroExplanation]:
+        self.is_hetero = isinstance(x, dict)
+        self._train(model, x, edge_index, target=target, index=index, **kwargs)
+        explanation = self._create_explanation()
         self._clean_model(model)
+        return explanation
+
+    def _create_explanation(self) -> Union[Explanation, HeteroExplanation]:
+        """Create an explanation object from the current masks."""
+        if self.is_hetero:
+            # For heterogeneous graphs, process each type separately
+            node_mask_dict = {}
+            edge_mask_dict = {}
+
+            for node_type, mask in self.node_mask.items():
+                if mask is not None:
+                    node_mask_dict[node_type] = self._post_process_mask(
+                        mask,
+                        self.hard_node_mask[node_type],
+                        apply_sigmoid=True,
+                    )
+
+            for edge_type, mask in self.edge_mask.items():
+                if mask is not None:
+                    edge_mask_dict[edge_type] = self._post_process_mask(
+                        mask,
+                        self.hard_edge_mask[edge_type],
+                        apply_sigmoid=True,
+                    )
+
+            # Create heterogeneous explanation
+            explanation = HeteroExplanation()
+            explanation.set_value_dict('node_mask', node_mask_dict)
+            explanation.set_value_dict('edge_mask', edge_mask_dict)
 
-        return Explanation(node_mask=node_mask, edge_mask=edge_mask)
+        else:
+            # For homogeneous graphs, process single masks
+            node_mask = self._post_process_mask(
+                self.node_mask,
+                self.hard_node_mask,
+                apply_sigmoid=True,
+            )
+            edge_mask = self._post_process_mask(
+                self.edge_mask,
+                self.hard_edge_mask,
+                apply_sigmoid=True,
+            )
+
+            # Create homogeneous explanation
+            explanation = Explanation(node_mask=node_mask, edge_mask=edge_mask)
+
+        return explanation
 
     def supports(self) -> bool:
         return True
 
+    @overload
     def _train(
         self,
         model: torch.nn.Module,
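
With the overloads above, the same `GNNExplainer` instance accepts either a homogeneous `(x, edge_index)` pair or per-type dictionaries and returns an `Explanation` or a `HeteroExplanation` accordingly. A hedged usage sketch of the heterogeneous path; `hetero_model`, `data`, and the node type `'paper'` are placeholders rather than anything defined in this diff, and the `Explainer` wrapper arguments follow its existing interface:

```python
from torch_geometric.explain import Explainer
from torch_geometric.explain.algorithm import GNNExplainer

explainer = Explainer(
    model=hetero_model,                  # placeholder: e.g. a to_hetero() GNN
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(mode='multiclass_classification', task_level='node',
                      return_type='log_probs'),
)

# Passing dictionaries selects the heterogeneous overload added above and
# yields a HeteroExplanation holding one mask per node/edge type, e.g.
# explanation['paper'].node_mask for a node type named 'paper'.
explanation = explainer(data.x_dict, data.edge_index_dict, index=10)
```
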
@@ -113,57 +183,222 @@ class GNNExplainer(ExplainerAlgorithm):
         target: Tensor,
         index: Optional[Union[int, Tensor]] = None,
         **kwargs,
-    ):
+    ) -> None:
+        ...
+
+    @overload
+    def _train(
+        self,
+        model: torch.nn.Module,
+        x: Dict[NodeType, Tensor],
+        edge_index: Dict[EdgeType, Tensor],
+        *,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ) -> None:
+        ...
+
+    def _train(
+        self,
+        model: torch.nn.Module,
+        x: Union[Tensor, Dict[NodeType, Tensor]],
+        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+        *,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ) -> None:
+        # Initialize masks based on input type
         self._initialize_masks(x, edge_index)
 
-        parameters = []
-        if self.node_mask is not None:
-            parameters.append(self.node_mask)
-        if self.edge_mask is not None:
-            set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)
-            parameters.append(self.edge_mask)
+        # Collect parameters for optimization
+        parameters = self._collect_parameters(model, edge_index)
 
+        # Create optimizer
         optimizer = torch.optim.Adam(parameters, lr=self.lr)
 
+        # Training loop
         for i in range(self.epochs):
             optimizer.zero_grad()
 
-            h = x if self.node_mask is None else x * self.node_mask.sigmoid()
-            y_hat, y = model(h, edge_index, **kwargs), target
+            # Forward pass with masked inputs
+            y_hat = self._forward_with_masks(model, x, edge_index, **kwargs)
+            y = target
 
+            # Handle index if provided
             if index is not None:
                 y_hat, y = y_hat[index], y[index]
 
+            # Calculate loss
             loss = self._loss(y_hat, y)
 
+            # Backward pass
             loss.backward()
             optimizer.step()
 
-            # In the first iteration,
-            #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # In the first iteration, collect gradients to identify important
+            # nodes/edges
+            if i == 0:
+                self._collect_gradients()
+
+    def _collect_parameters(self, model, edge_index):
+        """Collect parameters for optimization."""
+        parameters = []
+
+        if self.is_hetero:
+            # For heterogeneous graphs, collect parameters from all types
+            for mask in self.node_mask.values():
+                if mask is not None:
+                    parameters.append(mask)
+            if any(v is not None for v in self.edge_mask.values()):
+                set_hetero_masks(model, self.edge_mask, edge_index)
+                for mask in self.edge_mask.values():
+                    if mask is not None:
+                        parameters.append(mask)
+        else:
+            # For homogeneous graphs, collect single parameters
+            if self.node_mask is not None:
+                parameters.append(self.node_mask)
+            if self.edge_mask is not None:
+                set_masks(model, self.edge_mask, edge_index,
+                          apply_sigmoid=True)
+                parameters.append(self.edge_mask)
+
+        return parameters
+
+    @overload
+    def _forward_with_masks(
+        self,
+        model: torch.nn.Module,
+        x: Tensor,
+        edge_index: Tensor,
+        **kwargs,
+    ) -> Tensor:
+        ...
+
+    @overload
+    def _forward_with_masks(
+        self,
+        model: torch.nn.Module,
+        x: Dict[NodeType, Tensor],
+        edge_index: Dict[EdgeType, Tensor],
+        **kwargs,
+    ) -> Tensor:
+        ...
+
+    def _forward_with_masks(
+        self,
+        model: torch.nn.Module,
+        x: Union[Tensor, Dict[NodeType, Tensor]],
+        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+        **kwargs,
+    ) -> Tensor:
+        """Forward pass with masked inputs."""
+        if self.is_hetero:
+            # Apply masks to heterogeneous inputs
+            h_dict = {}
+            for node_type, features in x.items():
+                if node_type in self.node_mask and self.node_mask[
+                        node_type] is not None:
+                    h_dict[node_type] = features * self.node_mask[
+                        node_type].sigmoid()
+                else:
+                    h_dict[node_type] = features
+
+            # Forward pass with masked features
+            return model(h_dict, edge_index, **kwargs)
+        else:
+            # Apply mask to homogeneous input
+            h = x if self.node_mask is None else x * self.node_mask.sigmoid()
+
+            # Forward pass with masked features
+            return model(h, edge_index, **kwargs)
+
+    def _initialize_masks(
+        self,
+        x: Union[Tensor, Dict[NodeType, Tensor]],
+        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+    ) -> None:
         node_mask_type = self.explainer_config.node_mask_type
         edge_mask_type = self.explainer_config.edge_mask_type
 
-        device = x.device
-        (N, F), E = x.size(), edge_index.size(1)
+        if self.is_hetero:
+            # Initialize dictionaries for heterogeneous masks
+            self.node_mask = {}
+            self.hard_node_mask = {}
+            self.edge_mask = {}
+            self.hard_edge_mask = {}
+
+            # Initialize node masks for each node type
+            for node_type, features in x.items():
+                device = features.device
+                N, F = features.size()
+                self._initialize_node_mask(node_mask_type, node_type, N, F,
+                                           device)
+
+            # Initialize edge masks for each edge type
+            for edge_type, indices in edge_index.items():
+                device = indices.device
+                E = indices.size(1)
+                N = max(indices.max().item() + 1,
+                        max(feat.size(0) for feat in x.values()))
+                self._initialize_edge_mask(edge_mask_type, edge_type, E, N,
+                                           device)
+        else:
+            # Initialize masks for homogeneous graph
+            device = x.device
+            (N, F), E = x.size(), edge_index.size(1)
+
+            # Initialize homogeneous node and edge masks
+            self._initialize_homogeneous_masks(node_mask_type, edge_mask_type,
+                                               N, F, E, device)
+
+    def _initialize_node_mask(
+        self,
+        node_mask_type,
+        node_type,
+        N,
+        F,
+        device,
+    ) -> None:
+        """Initialize node mask for a specific node type."""
+        std = 0.1
+        if node_mask_type is None:
+            self.node_mask[node_type] = None
+            self.hard_node_mask[node_type] = None
+        elif node_mask_type == MaskType.object:
+            self.node_mask[node_type] = Parameter(
+                torch.randn(N, 1, device=device) * std)
+            self.hard_node_mask[node_type] = None
+        elif node_mask_type == MaskType.attributes:
+            self.node_mask[node_type] = Parameter(
+                torch.randn(N, F, device=device) * std)
+            self.hard_node_mask[node_type] = None
+        elif node_mask_type == MaskType.common_attributes:
+            self.node_mask[node_type] = Parameter(
+                torch.randn(1, F, device=device) * std)
+            self.hard_node_mask[node_type] = None
+        else:
+            raise ValueError(f"Invalid node mask type: {node_mask_type}")
+
+    def _initialize_edge_mask(self, edge_mask_type, edge_type, E, N, device):
+        """Initialize edge mask for a specific edge type."""
+        if edge_mask_type is None:
+            self.edge_mask[edge_type] = None
+            self.hard_edge_mask[edge_type] = None
+        elif edge_mask_type == MaskType.object:
+            std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
+            self.edge_mask[edge_type] = Parameter(
+                torch.randn(E, device=device) * std)
+            self.hard_edge_mask[edge_type] = None
+        else:
+            raise ValueError(f"Invalid edge mask type: {edge_mask_type}")
 
+    def _initialize_homogeneous_masks(self, node_mask_type, edge_mask_type, N,
+                                      F, E, device):
+        """Initialize masks for homogeneous graph."""
+        # Initialize node mask
         std = 0.1
         if node_mask_type is None:
             self.node_mask = None
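
For reference, the initializers added above use the following mask shapes (a restatement of the code with toy sizes; in the heterogeneous case `N`, `F`, and `E` are counted per node or edge type):

```python
from math import sqrt

import torch
from torch.nn.parameter import Parameter

N, F, E = 4, 3, 5
std = 0.1

node_mask_object = Parameter(torch.randn(N, 1) * std)  # MaskType.object: one weight per node
node_mask_attrs = Parameter(torch.randn(N, F) * std)   # MaskType.attributes: per node feature
node_mask_common = Parameter(torch.randn(1, F) * std)  # MaskType.common_attributes: shared row

edge_std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
edge_mask_object = Parameter(torch.randn(E) * edge_std)  # MaskType.object: one weight per edge
```
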
@@ -174,43 +409,145 @@ class GNNExplainer(ExplainerAlgorithm):
         elif node_mask_type == MaskType.common_attributes:
             self.node_mask = Parameter(torch.randn(1, F, device=device) * std)
         else:
-            assert False
+            raise ValueError(f"Invalid node mask type: {node_mask_type}")
 
+        # Initialize edge mask
         if edge_mask_type is None:
             self.edge_mask = None
         elif edge_mask_type == MaskType.object:
             std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
             self.edge_mask = Parameter(torch.randn(E, device=device) * std)
         else:
-            assert False
+            raise ValueError(f"Invalid edge mask type: {edge_mask_type}")
+
+    def _collect_gradients(self) -> None:
+        if self.is_hetero:
+            self._collect_hetero_gradients()
+        else:
+            self._collect_homo_gradients()
+
+    def _collect_hetero_gradients(self):
+        """Collect gradients for heterogeneous graph."""
+        for node_type, mask in self.node_mask.items():
+            if mask is not None:
+                if mask.grad is None:
+                    raise ValueError(
+                        f"Could not compute gradients for node masks of type "
+                        f"'{node_type}'. Please make sure that node masks are "
+                        f"used inside the model or disable it via "
+                        f"`node_mask_type=None`.")
+
+                self.hard_node_mask[node_type] = mask.grad != 0.0
+
+        for edge_type, mask in self.edge_mask.items():
+            if mask is not None:
+                if mask.grad is None:
+                    raise ValueError(
+                        f"Could not compute gradients for edge masks of type "
+                        f"'{edge_type}'. Please make sure that edge masks are "
+                        f"used inside the model or disable it via "
+                        f"`edge_mask_type=None`.")
+                self.hard_edge_mask[edge_type] = mask.grad != 0.0
+
+    def _collect_homo_gradients(self):
+        """Collect gradients for homogeneous graph."""
+        if self.node_mask is not None:
+            if self.node_mask.grad is None:
+                raise ValueError("Could not compute gradients for node "
+                                 "features. Please make sure that node "
+                                 "features are used inside the model or "
+                                 "disable it via `node_mask_type=None`.")
+            self.hard_node_mask = self.node_mask.grad != 0.0
+
+        if self.edge_mask is not None:
+            if self.edge_mask.grad is None:
+                raise ValueError("Could not compute gradients for edges. "
+                                 "Please make sure that edges are used "
+                                 "via message passing inside the model or "
+                                 "disable it via `edge_mask_type=None`.")
+            self.hard_edge_mask = self.edge_mask.grad != 0.0
 
     def _loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
+        # Calculate base loss based on model configuration
+        loss = self._calculate_base_loss(y_hat, y)
+
+        # Apply regularization based on graph type
+        if self.is_hetero:
+            # Apply regularization for heterogeneous graph
+            loss = self._apply_hetero_regularization(loss)
+        else:
+            # Apply regularization for homogeneous graph
+            loss = self._apply_homo_regularization(loss)
+
+        return loss
+
+    def _calculate_base_loss(self, y_hat, y):
+        """Calculate base loss based on model configuration."""
         if self.model_config.mode == ModelMode.binary_classification:
-            loss = self._loss_binary_classification(y_hat, y)
+            return self._loss_binary_classification(y_hat, y)
         elif self.model_config.mode == ModelMode.multiclass_classification:
-            loss = self._loss_multiclass_classification(y_hat, y)
+            return self._loss_multiclass_classification(y_hat, y)
         elif self.model_config.mode == ModelMode.regression:
-            loss = self._loss_regression(y_hat, y)
+            return self._loss_regression(y_hat, y)
         else:
-            assert False
+            raise ValueError(f"Invalid model mode: {self.model_config.mode}")
+
+    def _apply_hetero_regularization(self, loss):
+        """Apply regularization for heterogeneous graph."""
+        # Apply regularization for each edge type
+        for edge_type, mask in self.edge_mask.items():
+            if (mask is not None
+                    and self.hard_edge_mask[edge_type] is not None):
+                loss = self._add_mask_regularization(
+                    loss, mask, self.hard_edge_mask[edge_type],
+                    self.coeffs['edge_size'], self.coeffs['edge_reduction'],
+                    self.coeffs['edge_ent'])
+
+        # Apply regularization for each node type
+        for node_type, mask in self.node_mask.items():
+            if (mask is not None
+                    and self.hard_node_mask[node_type] is not None):
+                loss = self._add_mask_regularization(
+                    loss, mask, self.hard_node_mask[node_type],
+                    self.coeffs['node_feat_size'],
+                    self.coeffs['node_feat_reduction'],
+                    self.coeffs['node_feat_ent'])
 
+        return loss
+
+    def _apply_homo_regularization(self, loss):
+        """Apply regularization for homogeneous graph."""
+        # Apply regularization for edge mask
         if self.hard_edge_mask is not None:
             assert self.edge_mask is not None
-            m = self.edge_mask[self.hard_edge_mask].sigmoid()
-            edge_reduce = getattr(torch, self.coeffs['edge_reduction'])
-            loss = loss + self.coeffs['edge_size'] * edge_reduce(m)
-            ent = -m * torch.log(m + self.coeffs['EPS']) - (
-                1 - m) * torch.log(1 - m + self.coeffs['EPS'])
-            loss = loss + self.coeffs['edge_ent'] * ent.mean()
+            loss = self._add_mask_regularization(loss, self.edge_mask,
                                                  self.hard_edge_mask,
                                                  self.coeffs['edge_size'],
                                                  self.coeffs['edge_reduction'],
                                                  self.coeffs['edge_ent'])
 
+        # Apply regularization for node mask
         if self.hard_node_mask is not None:
             assert self.node_mask is not None
-            m = self.node_mask[self.hard_node_mask].sigmoid()
-            node_reduce = getattr(torch, self.coeffs['node_feat_reduction'])
-            loss = loss + self.coeffs['node_feat_size'] * node_reduce(m)
-            ent = -m * torch.log(m + self.coeffs['EPS']) - (
-                1 - m) * torch.log(1 - m + self.coeffs['EPS'])
-            loss = loss + self.coeffs['node_feat_ent'] * ent.mean()
+            loss = self._add_mask_regularization(
+                loss, self.node_mask, self.hard_node_mask,
+                self.coeffs['node_feat_size'],
+                self.coeffs['node_feat_reduction'],
+                self.coeffs['node_feat_ent'])
+
+        return loss
+
+    def _add_mask_regularization(self, loss, mask, hard_mask, size_coeff,
+                                 reduction_name, ent_coeff):
+        """Add size and entropy regularization for a mask."""
+        m = mask[hard_mask].sigmoid()
+        reduce_fn = getattr(torch, reduction_name)
+        # Add size regularization
+        loss = loss + size_coeff * reduce_fn(m)
+        # Add entropy regularization
+        ent = -m * torch.log(m + self.coeffs['EPS']) - (
+            1 - m) * torch.log(1 - m + self.coeffs['EPS'])
+        loss = loss + ent_coeff * ent.mean()
 
         return loss
 
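
The new `_add_mask_regularization` helper factors out the size-plus-entropy penalty that was previously written out separately for the edge and node masks. A standalone sketch of the same term; the `eps` value here is an assumption for illustration, while the class reads it from `self.coeffs['EPS']`:

```python
import torch


def mask_regularization(mask, hard_mask, size_coeff, reduction, ent_coeff,
                        eps=1e-15):  # eps assumed for illustration
    """Size + entropy penalty over the entries selected by `hard_mask`."""
    m = mask[hard_mask].sigmoid()
    reduce_fn = getattr(torch, reduction)      # e.g. torch.sum or torch.mean
    size_term = size_coeff * reduce_fn(m)      # encourages a sparse mask
    ent = -m * torch.log(m + eps) - (1 - m) * torch.log(1 - m + eps)
    return size_term + ent_coeff * ent.mean()  # pushes entries towards 0 or 1


# Toy call using the 'edge_size' and 'edge_reduction' defaults shown earlier
# in this diff (the entropy coefficient 1.0 is an arbitrary example value):
penalty = mask_regularization(torch.randn(6), torch.ones(6, dtype=torch.bool),
                              size_coeff=0.005, reduction='sum', ent_coeff=1.0)
```
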
@@ -223,7 +560,7 @@ class GNNExplainer(ExplainerAlgorithm):
 class GNNExplainer_:
     r"""Deprecated version for :class:`GNNExplainer`."""
 
-    coeffs = GNNExplainer.coeffs
+    coeffs = GNNExplainer.default_coeffs
 
     conversion_node_mask_type = {
         'feature': 'common_attributes',
torch_geometric/explain/algorithm/graphmask_explainer.py

@@ -202,25 +202,25 @@ class GraphMaskExplainer(ExplainerAlgorithm):
 
         baselines, self.gates, full_biases = [], torch.nn.ModuleList(), []
 
-        for v_dim, m_dim,
+        for v_dim, m_dim, o_dim in zip(i_dim, j_dim, h_dim):
             self.transform, self.layer_norm = [], []
             input_dims = [v_dim, m_dim, v_dim]
             for _, input_dim in enumerate(input_dims):
                 self.transform.append(
-                    Linear(input_dim,
-                self.layer_norm.append(LayerNorm(
+                    Linear(input_dim, o_dim, bias=False).to(device))
+                self.layer_norm.append(LayerNorm(o_dim).to(device))
 
             self.transforms = torch.nn.ModuleList(self.transform)
             self.layer_norms = torch.nn.ModuleList(self.layer_norm)
 
             self.full_bias = Parameter(
-                torch.tensor(
+                torch.tensor(o_dim, dtype=torch.float, device=device))
             full_biases.append(self.full_bias)
 
-            self.reset_parameters(input_dims,
+            self.reset_parameters(input_dims, o_dim)
 
             self.non_linear = ReLU()
-            self.output_layer = Linear(
+            self.output_layer = Linear(o_dim, 1).to(device)
 
             gate = [
                 self.transforms, self.layer_norms, self.non_linear,
@@ -274,7 +274,7 @@ class GraphMaskExplainer(ExplainerAlgorithm):
         elif self.model_config.mode == ModelMode.regression:
             loss = self._loss_regression(y_hat, y)
         else:
-            assert False
+            raise AssertionError()
 
         g = torch.relu(loss - self.allowance).mean()
         f = penalty * self.penalty_scaling
@@ -385,7 +385,7 @@ class GraphMaskExplainer(ExplainerAlgorithm):
                     f'Train explainer for graph {index} with layer '
                    f'{layer}')
             self._enable_layer(layer)
-            for
+            for _ in range(self.epochs):
                 with torch.no_grad():
                     model(x, edge_index, **kwargs)
                 gates, total_penalty = [], 0
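
For context, a hedged sketch of how `GraphMaskExplainer` (the class touched by the last three hunks) is typically driven through the `Explainer` interface; `model` and `data` are placeholders, and the constructor arguments shown (number of message-passing layers, `epochs`) reflect its existing signature rather than anything introduced in this diff:

```python
from torch_geometric.explain import Explainer
from torch_geometric.explain.algorithm import GraphMaskExplainer

explainer = Explainer(
    model=model,                                # placeholder GNN with two layers
    algorithm=GraphMaskExplainer(2, epochs=5),  # num_layers=2, short training run
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(mode='multiclass_classification', task_level='node',
                      return_type='log_probs'),
)
explanation = explainer(data.x, data.edge_index, index=10)
```
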