pyg-nightly 2.7.0.dev20250416__py3-none-any.whl → 2.7.0.dev20250417__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250416.dist-info → pyg_nightly-2.7.0.dev20250417.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250416.dist-info → pyg_nightly-2.7.0.dev20250417.dist-info}/RECORD +8 -8
- torch_geometric/__init__.py +1 -1
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/utils/__init__.py +2 -1
- torch_geometric/utils/embedding.py +88 -1
- {pyg_nightly-2.7.0.dev20250416.dist-info → pyg_nightly-2.7.0.dev20250417.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250416.dist-info → pyg_nightly-2.7.0.dev20250417.dist-info}/licenses/LICENSE +0 -0
{pyg_nightly-2.7.0.dev20250416.dist-info → pyg_nightly-2.7.0.dev20250417.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: pyg-nightly
|
3
|
-
Version: 2.7.0.dev20250416
|
3
|
+
Version: 2.7.0.dev20250417
|
4
4
|
Summary: Graph Neural Network Library for PyTorch
|
5
5
|
Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
|
6
6
|
Author-email: Matthias Fey <matthias@pyg.org>
|
@@ -1,4 +1,4 @@
|
|
1
|
-
torch_geometric/__init__.py,sha256=
|
1
|
+
torch_geometric/__init__.py,sha256=0IDLxL4UGxFzwGVGf-wSekwHSAK3yYrKOr5mJrBk8sM,1978
|
2
2
|
torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
|
3
3
|
torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
|
4
4
|
torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
|
@@ -202,7 +202,7 @@ torch_geometric/explain/algorithm/captum_explainer.py,sha256=oz-c40hvdzii4_chEQP
|
|
202
202
|
torch_geometric/explain/algorithm/dummy_explainer.py,sha256=jvcVQmfngmUWgoKa5p7CXzju2HM5D5DfieJhZW3gbLc,2872
|
203
203
|
torch_geometric/explain/algorithm/gnn_explainer.py,sha256=iu45fGWdd4c6wNczWEAT-29HCAz7ncuoaS6cpx-xDJM,24660
|
204
204
|
torch_geometric/explain/algorithm/graphmask_explainer.py,sha256=T2B081dK-JSpaQmutnkQd5xF3JF49_dPZCOgwqIKJDk,21367
|
205
|
-
torch_geometric/explain/algorithm/pg_explainer.py,sha256=
|
205
|
+
torch_geometric/explain/algorithm/pg_explainer.py,sha256=LMlNcqSqtEP-IzYA7Xix6FoAogcrLUaEUAxDVyz2eyc,20162
|
206
206
|
torch_geometric/explain/algorithm/utils.py,sha256=eh0ARPG41V7piVw5jdMYpV0p7WjTlpehnY-bWqPV_zg,2564
|
207
207
|
torch_geometric/explain/metric/__init__.py,sha256=swLeuWVaM3K7UvowsH7q3BzfTq_W1vhcFY8nEP7vFPQ,301
|
208
208
|
torch_geometric/explain/metric/basic.py,sha256=qN-cho4lxwPlw_X26svJrW5QOnw5GB3lLKf0Js_6rBE,1888
|
@@ -584,7 +584,7 @@ torch_geometric/transforms/to_superpixels.py,sha256=g8ysBv-ezcHn2gHucKuBtnbe-kBD
|
|
584
584
|
torch_geometric/transforms/to_undirected.py,sha256=oklgrNzev7HjvVaBHwPQFo0RxcQpmcIebNbcv6vNCtY,2972
|
585
585
|
torch_geometric/transforms/two_hop.py,sha256=XxZl3eztTjE00ZlyAIqYu36rjaRddQT-1v4AFF9VUBc,1313
|
586
586
|
torch_geometric/transforms/virtual_node.py,sha256=FMGT6LZBH-SU2zmp76GKNqJBZ8PyS1_6Em2BbVhv8Tw,2932
|
587
|
-
torch_geometric/utils/__init__.py,sha256=
|
587
|
+
torch_geometric/utils/__init__.py,sha256=aVet2bjRvr3URikJ6LpjLATz447YuBS6FuSu5l3JwLY,4982
|
588
588
|
torch_geometric/utils/_assortativity.py,sha256=pe2Hv5xLWhTW7dgqVWNiwDgDVMxMbliTdLeQf5Y65Ug,2347
|
589
589
|
torch_geometric/utils/_coalesce.py,sha256=m4s_maBhib0jByQi6Cd8dazzhFVshZXLfB9aykCZT2g,6769
|
590
590
|
torch_geometric/utils/_degree.py,sha256=FcsGx5cQdrBmoCQ4qQ2csjsTiDICP1as4x1HD9y5XVk,1017
|
@@ -613,7 +613,7 @@ torch_geometric/utils/augmentation.py,sha256=1F0YCuaklZ9ZbXxdFV0oOoemWvLd8p60WvF
|
|
613
613
|
torch_geometric/utils/convert.py,sha256=0KEJoBOzU-w-mMQu9QYaMhUqcrGBxBmeRl0hv8NPvII,21697
|
614
614
|
torch_geometric/utils/cross_entropy.py,sha256=ZFS5bivtzv3EV9zqgKsekmuQyoZZggPSclhl_tRNHxo,3047
|
615
615
|
torch_geometric/utils/dropout.py,sha256=gg0rDnD4FLvBaKSoLAkZwViAQflhLefJm6_Mju5dmQs,11416
|
616
|
-
torch_geometric/utils/embedding.py,sha256=
|
616
|
+
torch_geometric/utils/embedding.py,sha256=b-CQ-aapEgahxSS7fuL4aNQX6GJROboV0xclZ_MmwO0,5179
|
617
617
|
torch_geometric/utils/functions.py,sha256=orQdS_6EpzWSmBHSok3WhxCzLy9neB-cin1aTnlXY-8,703
|
618
618
|
torch_geometric/utils/geodesic.py,sha256=-xsqE3FZU7Y9gMbucIlGJ4FM-3nk8o0AQBxIdN-QfEw,4770
|
619
619
|
torch_geometric/utils/hetero.py,sha256=ok4uAAOyMiaeEPmvyS4DNoDwdKnLS2gmgs5WVVklxOo,5539
|
@@ -636,7 +636,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
|
|
636
636
|
torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
|
637
637
|
torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
|
638
638
|
torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
|
639
|
-
pyg_nightly-2.7.0.
|
640
|
-
pyg_nightly-2.7.0.
|
641
|
-
pyg_nightly-2.7.0.
|
642
|
-
pyg_nightly-2.7.0.
|
639
|
+
pyg_nightly-2.7.0.dev20250417.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
|
640
|
+
pyg_nightly-2.7.0.dev20250417.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
|
641
|
+
pyg_nightly-2.7.0.dev20250417.dist-info/METADATA,sha256=o5Go2F_1ZpI3j7XxxvhyaRfDh1YnHLFrtEmZLBOvw3w,62979
|
642
|
+
pyg_nightly-2.7.0.dev20250417.dist-info/RECORD,,
|
torch_geometric/__init__.py
CHANGED
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
|
|
31
31
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
32
32
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
33
33
|
|
34
|
-
__version__ = '2.7.0.dev20250416'
|
34
|
+
__version__ = '2.7.0.dev20250417'
|
35
35
|
|
36
36
|
__all__ = [
|
37
37
|
'Index',
|
@@ -1,21 +1,26 @@
|
|
1
1
|
import logging
|
2
|
-
from typing import Optional, Union
|
2
|
+
from typing import Dict, Optional, Tuple, Union, overload
|
3
3
|
|
4
4
|
import torch
|
5
5
|
from torch import Tensor
|
6
6
|
from torch.nn import ReLU, Sequential
|
7
7
|
|
8
|
-
from torch_geometric.explain import Explanation
|
8
|
+
from torch_geometric.explain import Explanation, HeteroExplanation
|
9
9
|
from torch_geometric.explain.algorithm import ExplainerAlgorithm
|
10
|
-
from torch_geometric.explain.algorithm.utils import
|
10
|
+
from torch_geometric.explain.algorithm.utils import (
|
11
|
+
clear_masks,
|
12
|
+
set_hetero_masks,
|
13
|
+
set_masks,
|
14
|
+
)
|
11
15
|
from torch_geometric.explain.config import (
|
12
16
|
ExplanationType,
|
13
17
|
ModelMode,
|
14
18
|
ModelTaskLevel,
|
15
19
|
)
|
16
|
-
from torch_geometric.nn import Linear
|
20
|
+
from torch_geometric.nn import HANConv, HeteroConv, HGTConv, Linear
|
17
21
|
from torch_geometric.nn.inits import reset
|
18
|
-
from torch_geometric.
|
22
|
+
from torch_geometric.typing import EdgeType, NodeType
|
23
|
+
from torch_geometric.utils import get_embeddings, get_embeddings_hetero
|
19
24
|
|
20
25
|
|
21
26
|
class PGExplainer(ExplainerAlgorithm):
|
@@ -62,6 +67,13 @@ class PGExplainer(ExplainerAlgorithm):
|
|
62
67
|
'bias': 0.01,
|
63
68
|
}
|
64
69
|
|
70
|
+
# NOTE: Add more in the future as needed.
|
71
|
+
SUPPORTED_HETERO_MODELS = [
|
72
|
+
HGTConv,
|
73
|
+
HANConv,
|
74
|
+
HeteroConv,
|
75
|
+
]
|
76
|
+
|
65
77
|
def __init__(self, epochs: int, lr: float = 0.003, **kwargs):
|
66
78
|
super().__init__()
|
67
79
|
self.epochs = epochs
|
@@ -75,11 +87,13 @@ class PGExplainer(ExplainerAlgorithm):
|
|
75
87
|
)
|
76
88
|
self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=lr)
|
77
89
|
self._curr_epoch = -1
|
90
|
+
self.is_hetero = False
|
78
91
|
|
79
92
|
def reset_parameters(self):
|
80
93
|
r"""Resets all learnable parameters of the module."""
|
81
94
|
reset(self.mlp)
|
82
95
|
|
96
|
+
@overload
|
83
97
|
def train(
|
84
98
|
self,
|
85
99
|
epoch: int,
|
@@ -90,17 +104,44 @@ class PGExplainer(ExplainerAlgorithm):
|
|
90
104
|
target: Tensor,
|
91
105
|
index: Optional[Union[int, Tensor]] = None,
|
92
106
|
**kwargs,
|
93
|
-
):
|
107
|
+
) -> float:
|
108
|
+
...
|
109
|
+
|
110
|
+
@overload
|
111
|
+
def train(
|
112
|
+
self,
|
113
|
+
epoch: int,
|
114
|
+
model: torch.nn.Module,
|
115
|
+
x: Dict[NodeType, Tensor],
|
116
|
+
edge_index: Dict[EdgeType, Tensor],
|
117
|
+
*,
|
118
|
+
target: Tensor,
|
119
|
+
index: Optional[Union[int, Tensor]] = None,
|
120
|
+
**kwargs,
|
121
|
+
) -> float:
|
122
|
+
...
|
123
|
+
|
124
|
+
def train(
|
125
|
+
self,
|
126
|
+
epoch: int,
|
127
|
+
model: torch.nn.Module,
|
128
|
+
x: Union[Tensor, Dict[NodeType, Tensor]],
|
129
|
+
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
|
130
|
+
*,
|
131
|
+
target: Tensor,
|
132
|
+
index: Optional[Union[int, Tensor]] = None,
|
133
|
+
**kwargs,
|
134
|
+
) -> float:
|
94
135
|
r"""Trains the underlying explainer model.
|
95
136
|
Needs to be called before being able to make predictions.
|
96
137
|
|
97
138
|
Args:
|
98
139
|
epoch (int): The current epoch of the training phase.
|
99
140
|
model (torch.nn.Module): The model to explain.
|
100
|
-
x (torch.Tensor): The input node
|
101
|
-
homogeneous
|
102
|
-
edge_index (torch.Tensor): The input
|
103
|
-
|
141
|
+
x (torch.Tensor or Dict[str, torch.Tensor]): The input node
|
142
|
+
features. Can be either homogeneous or heterogeneous.
|
143
|
+
edge_index (torch.Tensor or Dict[Tuple[str, str, str], torch.Tensor]): The input
|
144
|
+
edge indices. Can be either homogeneous or heterogeneous.
|
104
145
|
target (torch.Tensor): The target of the model.
|
105
146
|
index (int or torch.Tensor, optional): The index of the model
|
106
147
|
output to explain. Needs to be a single index.
|
@@ -108,9 +149,9 @@ class PGExplainer(ExplainerAlgorithm):
|
|
108
149
|
**kwargs (optional): Additional keyword arguments passed to
|
109
150
|
:obj:`model`.
|
110
151
|
"""
|
111
|
-
|
112
|
-
|
113
|
-
|
152
|
+
self.is_hetero = isinstance(x, dict)
|
153
|
+
if self.is_hetero:
|
154
|
+
assert isinstance(edge_index, dict)
|
114
155
|
|
115
156
|
if self.model_config.task_level == ModelTaskLevel.node:
|
116
157
|
if index is None:
|
@@ -121,35 +162,68 @@ class PGExplainer(ExplainerAlgorithm):
|
|
121
162
|
raise ValueError(f"Only scalars are supported for the 'index' "
|
122
163
|
f"argument in '{self.__class__.__name__}'")
|
123
164
|
|
124
|
-
|
165
|
+
# Get embeddings based on whether the graph is homogeneous or
|
166
|
+
# heterogeneous
|
167
|
+
node_embeddings = self._get_embeddings(model, x, edge_index, **kwargs)
|
125
168
|
|
169
|
+
# Train the model
|
126
170
|
self.optimizer.zero_grad()
|
127
171
|
temperature = self._get_temperature(epoch)
|
128
172
|
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
if self.
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
173
|
+
# Process embeddings and generate edge masks
|
174
|
+
edge_mask = self._generate_edge_masks(node_embeddings, edge_index,
|
175
|
+
index, temperature)
|
176
|
+
|
177
|
+
# Apply masks to the model
|
178
|
+
if self.is_hetero:
|
179
|
+
set_hetero_masks(model, edge_mask, edge_index, apply_sigmoid=True)
|
180
|
+
|
181
|
+
# For node-level tasks, we can compute hard masks
|
182
|
+
if self.model_config.task_level == ModelTaskLevel.node:
|
183
|
+
# Process each edge type separately
|
184
|
+
for edge_type, mask in edge_mask.items():
|
185
|
+
# Get the edge indices for this edge type
|
186
|
+
edges = edge_index[edge_type]
|
187
|
+
src_type, _, dst_type = edge_type
|
188
|
+
|
189
|
+
# Get hard masks for this specific edge type
|
190
|
+
_, hard_mask = self._get_hard_masks(
|
191
|
+
model, index, edges,
|
192
|
+
num_nodes=max(x[src_type].size(0),
|
193
|
+
x[dst_type].size(0)))
|
194
|
+
|
195
|
+
edge_mask[edge_type] = mask[hard_mask]
|
196
|
+
else:
|
197
|
+
# Apply masks for homogeneous graphs
|
198
|
+
set_masks(model, edge_mask, edge_index, apply_sigmoid=True)
|
199
|
+
|
200
|
+
# For node-level tasks, we may need to apply hard masks
|
201
|
+
hard_edge_mask = None
|
202
|
+
if self.model_config.task_level == ModelTaskLevel.node:
|
203
|
+
_, hard_edge_mask = self._get_hard_masks(
|
204
|
+
model, index, edge_index, num_nodes=x.size(0))
|
205
|
+
edge_mask = edge_mask[hard_edge_mask]
|
206
|
+
|
207
|
+
# Forward pass with masks applied
|
139
208
|
y_hat, y = model(x, edge_index, **kwargs), target
|
140
209
|
|
141
210
|
if index is not None:
|
142
211
|
y_hat, y = y_hat[index], y[index]
|
143
212
|
|
213
|
+
# Calculate loss
|
144
214
|
loss = self._loss(y_hat, y, edge_mask)
|
215
|
+
|
216
|
+
# Backward pass and optimization
|
145
217
|
loss.backward()
|
146
218
|
self.optimizer.step()
|
147
219
|
|
220
|
+
# Clean up
|
148
221
|
clear_masks(model)
|
149
222
|
self._curr_epoch = epoch
|
150
223
|
|
151
224
|
return float(loss)
|
152
225
|
|
226
|
+
@overload
|
153
227
|
def forward(
|
154
228
|
self,
|
155
229
|
model: torch.nn.Module,
|
@@ -160,9 +234,32 @@ class PGExplainer(ExplainerAlgorithm):
|
|
160
234
|
index: Optional[Union[int, Tensor]] = None,
|
161
235
|
**kwargs,
|
162
236
|
) -> Explanation:
|
163
|
-
|
164
|
-
|
165
|
-
|
237
|
+
...
|
238
|
+
|
239
|
+
@overload
|
240
|
+
def forward(
|
241
|
+
self,
|
242
|
+
model: torch.nn.Module,
|
243
|
+
x: Dict[NodeType, Tensor],
|
244
|
+
edge_index: Dict[EdgeType, Tensor],
|
245
|
+
*,
|
246
|
+
target: Tensor,
|
247
|
+
index: Optional[Union[int, Tensor]] = None,
|
248
|
+
**kwargs,
|
249
|
+
) -> HeteroExplanation:
|
250
|
+
...
|
251
|
+
|
252
|
+
def forward(
|
253
|
+
self,
|
254
|
+
model: torch.nn.Module,
|
255
|
+
x: Union[Tensor, Dict[NodeType, Tensor]],
|
256
|
+
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
|
257
|
+
*,
|
258
|
+
target: Tensor,
|
259
|
+
index: Optional[Union[int, Tensor]] = None,
|
260
|
+
**kwargs,
|
261
|
+
) -> Union[Explanation, HeteroExplanation]:
|
262
|
+
self.is_hetero = isinstance(x, dict)
|
166
263
|
|
167
264
|
if self._curr_epoch < self.epochs - 1: # Safety check:
|
168
265
|
raise ValueError(f"'{self.__class__.__name__}' is not yet fully "
|
@@ -171,7 +268,6 @@ class PGExplainer(ExplainerAlgorithm):
|
|
171
268
|
f"the underlying explainer model by running "
|
172
269
|
f"`explainer.algorithm.train(...)`.")
|
173
270
|
|
174
|
-
hard_edge_mask = None
|
175
271
|
if self.model_config.task_level == ModelTaskLevel.node:
|
176
272
|
if index is None:
|
177
273
|
raise ValueError(f"The 'index' argument needs to be provided "
|
@@ -181,20 +277,55 @@ class PGExplainer(ExplainerAlgorithm):
|
|
181
277
|
raise ValueError(f"Only scalars are supported for the 'index' "
|
182
278
|
f"argument in '{self.__class__.__name__}'")
|
183
279
|
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
280
|
+
# Get embeddings
|
281
|
+
node_embeddings = self._get_embeddings(model, x, edge_index, **kwargs)
|
282
|
+
|
283
|
+
# Generate explanations
|
284
|
+
if self.is_hetero:
|
285
|
+
# Generate edge masks for each edge type
|
286
|
+
edge_masks = {}
|
287
|
+
|
288
|
+
# Generate masks for each edge type
|
289
|
+
for edge_type, edge_idx in edge_index.items():
|
290
|
+
src_node_type, _, dst_node_type = edge_type
|
291
|
+
|
292
|
+
assert src_node_type in node_embeddings
|
293
|
+
assert dst_node_type in node_embeddings
|
294
|
+
|
295
|
+
inputs = self._get_inputs_hetero(node_embeddings, edge_type,
|
296
|
+
edge_idx, index)
|
297
|
+
logits = self.mlp(inputs).view(-1)
|
298
|
+
|
299
|
+
# For node-level explanations, get hard masks for this
|
300
|
+
# specific edge type
|
301
|
+
hard_edge_mask = None
|
302
|
+
if self.model_config.task_level == ModelTaskLevel.node:
|
303
|
+
_, hard_edge_mask = self._get_hard_masks(
|
304
|
+
model, index, edge_idx,
|
305
|
+
num_nodes=max(x[src_node_type].size(0),
|
306
|
+
x[dst_node_type].size(0)))
|
188
307
|
|
189
|
-
|
308
|
+
# Apply hard mask if available and it has any True values
|
309
|
+
edge_masks[edge_type] = self._post_process_mask(
|
310
|
+
logits, hard_edge_mask, apply_sigmoid=True)
|
190
311
|
|
191
|
-
|
192
|
-
|
312
|
+
explanation = HeteroExplanation()
|
313
|
+
explanation.set_value_dict('edge_mask', edge_masks)
|
314
|
+
return explanation
|
315
|
+
else:
|
316
|
+
hard_edge_mask = None
|
317
|
+
if self.model_config.task_level == ModelTaskLevel.node:
|
318
|
+
# We need to compute hard masks to properly clean up edges
|
319
|
+
_, hard_edge_mask = self._get_hard_masks(
|
320
|
+
model, index, edge_index, num_nodes=x.size(0))
|
193
321
|
|
194
|
-
|
195
|
-
|
322
|
+
inputs = self._get_inputs(node_embeddings, edge_index, index)
|
323
|
+
logits = self.mlp(inputs).view(-1)
|
196
324
|
|
197
|
-
|
325
|
+
edge_mask = self._post_process_mask(logits, hard_edge_mask,
|
326
|
+
apply_sigmoid=True)
|
327
|
+
|
328
|
+
return Explanation(edge_mask=edge_mask)
|
198
329
|
|
199
330
|
def supports(self) -> bool:
|
200
331
|
explanation_type = self.explainer_config.explanation_type
|
@@ -222,6 +353,76 @@ class PGExplainer(ExplainerAlgorithm):
|
|
222
353
|
|
223
354
|
###########################################################################
|
224
355
|
|
356
|
+
def _get_embeddings(self, model: torch.nn.Module, x: Union[Tensor,
|
357
|
+
Dict[NodeType,
|
358
|
+
Tensor]],
|
359
|
+
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
|
360
|
+
**kwargs) -> Union[Tensor, Dict[NodeType, Tensor]]:
|
361
|
+
"""Get embeddings from the model based on input type."""
|
362
|
+
if self.is_hetero:
|
363
|
+
# For heterogeneous graphs, get embeddings for each node type
|
364
|
+
embeddings_dict = get_embeddings_hetero(
|
365
|
+
model,
|
366
|
+
self.SUPPORTED_HETERO_MODELS,
|
367
|
+
x,
|
368
|
+
edge_index,
|
369
|
+
**kwargs,
|
370
|
+
)
|
371
|
+
|
372
|
+
# Use the last layer's embeddings for each node type
|
373
|
+
last_embedding_dict = {
|
374
|
+
node_type: embs[-1] if embs and len(embs) > 0 else None
|
375
|
+
for node_type, embs in embeddings_dict.items()
|
376
|
+
}
|
377
|
+
|
378
|
+
# Skip if no embeddings were captured
|
379
|
+
if not any(emb is not None
|
380
|
+
for emb in last_embedding_dict.values()):
|
381
|
+
raise ValueError(
|
382
|
+
"No embeddings were captured from the model. "
|
383
|
+
"Please check if the model architecture is supported.")
|
384
|
+
|
385
|
+
return last_embedding_dict
|
386
|
+
else:
|
387
|
+
# For homogeneous graphs, get embeddings directly
|
388
|
+
return get_embeddings(model, x, edge_index, **kwargs)[-1]
|
389
|
+
|
390
|
+
def _generate_edge_masks(
|
391
|
+
self, emb: Union[Tensor, Dict[NodeType, Tensor]],
|
392
|
+
edge_index: Union[Tensor,
|
393
|
+
Dict[EdgeType,
|
394
|
+
Tensor]], index: Optional[Union[int,
|
395
|
+
Tensor]],
|
396
|
+
temperature: float) -> Union[Tensor, Dict[EdgeType, Tensor]]:
|
397
|
+
"""Generate edge masks based on embeddings."""
|
398
|
+
if self.is_hetero:
|
399
|
+
# For heterogeneous graphs, generate masks for each edge type
|
400
|
+
edge_masks = {}
|
401
|
+
|
402
|
+
for edge_type, edge_idx in edge_index.items():
|
403
|
+
src, _, dst = edge_type
|
404
|
+
|
405
|
+
assert src in emb and dst in emb
|
406
|
+
# Generate inputs for this edge type
|
407
|
+
inputs = self._get_inputs_hetero(emb, edge_type, edge_idx,
|
408
|
+
index)
|
409
|
+
logits = self.mlp(inputs).view(-1)
|
410
|
+
edge_masks[edge_type] = self._concrete_sample(
|
411
|
+
logits, temperature)
|
412
|
+
|
413
|
+
# Ensure we have at least one valid edge mask
|
414
|
+
if not edge_masks:
|
415
|
+
raise ValueError(
|
416
|
+
"Could not generate edge masks for any edge type. "
|
417
|
+
"Please ensure the model architecture is supported.")
|
418
|
+
|
419
|
+
return edge_masks
|
420
|
+
else:
|
421
|
+
# For homogeneous graphs, generate a single mask
|
422
|
+
inputs = self._get_inputs(emb, edge_index, index)
|
423
|
+
logits = self.mlp(inputs).view(-1)
|
424
|
+
return self._concrete_sample(logits, temperature)
|
425
|
+
|
225
426
|
def _get_inputs(self, embedding: Tensor, edge_index: Tensor,
|
226
427
|
index: Optional[int] = None) -> Tensor:
|
227
428
|
zs = [embedding[edge_index[0]], embedding[edge_index[1]]]
|
@@ -230,6 +431,27 @@ class PGExplainer(ExplainerAlgorithm):
|
|
230
431
|
zs.append(embedding[index].view(1, -1).repeat(zs[0].size(0), 1))
|
231
432
|
return torch.cat(zs, dim=-1)
|
232
433
|
|
434
|
+
def _get_inputs_hetero(self, embedding_dict: Dict[NodeType, Tensor],
|
435
|
+
edge_type: Tuple[str, str, str], edge_index: Tensor,
|
436
|
+
index: Optional[int] = None) -> Tensor:
|
437
|
+
src, _, dst = edge_type
|
438
|
+
|
439
|
+
# Get embeddings for source and destination nodes
|
440
|
+
src_emb = embedding_dict[src]
|
441
|
+
dst_emb = embedding_dict[dst]
|
442
|
+
|
443
|
+
# Source and destination node embeddings
|
444
|
+
zs = [src_emb[edge_index[0]], dst_emb[edge_index[1]]]
|
445
|
+
|
446
|
+
# For node-level explanations, add the target node embedding
|
447
|
+
if self.model_config.task_level == ModelTaskLevel.node:
|
448
|
+
assert index is not None
|
449
|
+
# Assuming index refers to a node of type 'src'
|
450
|
+
target_emb = src_emb[index].view(1, -1).repeat(zs[0].size(0), 1)
|
451
|
+
zs.append(target_emb)
|
452
|
+
|
453
|
+
return torch.cat(zs, dim=-1)
|
454
|
+
|
233
455
|
def _get_temperature(self, epoch: int) -> float:
|
234
456
|
temp = self.coeffs['temp']
|
235
457
|
return temp[0] * pow(temp[1] / temp[0], epoch / self.epochs)
|
@@ -240,19 +462,55 @@ class PGExplainer(ExplainerAlgorithm):
|
|
240
462
|
eps = (1 - 2 * bias) * torch.rand_like(logits) + bias
|
241
463
|
return (eps.log() - (1 - eps).log() + logits) / temperature
|
242
464
|
|
243
|
-
def _loss(self, y_hat: Tensor, y: Tensor,
|
465
|
+
def _loss(self, y_hat: Tensor, y: Tensor,
|
466
|
+
edge_mask: Union[Tensor, Dict[EdgeType, Tensor]]) -> Tensor:
|
467
|
+
# Calculate base loss based on model configuration
|
468
|
+
loss = self._calculate_base_loss(y_hat, y)
|
469
|
+
|
470
|
+
# Apply regularization based on graph type
|
471
|
+
if self.is_hetero:
|
472
|
+
loss = self._apply_hetero_regularization(loss, edge_mask)
|
473
|
+
else:
|
474
|
+
loss = self._apply_homo_regularization(loss, edge_mask)
|
475
|
+
|
476
|
+
return loss
|
477
|
+
|
478
|
+
def _calculate_base_loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
|
479
|
+
"""Calculate base loss based on model configuration."""
|
244
480
|
if self.model_config.mode == ModelMode.binary_classification:
|
245
|
-
|
481
|
+
return self._loss_binary_classification(y_hat, y)
|
246
482
|
elif self.model_config.mode == ModelMode.multiclass_classification:
|
247
|
-
|
483
|
+
return self._loss_multiclass_classification(y_hat, y)
|
248
484
|
elif self.model_config.mode == ModelMode.regression:
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
485
|
+
return self._loss_regression(y_hat, y)
|
486
|
+
else:
|
487
|
+
raise ValueError(
|
488
|
+
f"Unsupported model mode: {self.model_config.mode}")
|
489
|
+
|
490
|
+
def _apply_hetero_regularization(
|
491
|
+
self, loss: Tensor, edge_mask: Dict[EdgeType, Tensor]) -> Tensor:
|
492
|
+
"""Apply regularization for heterogeneous graph."""
|
493
|
+
for _, mask in edge_mask.items():
|
494
|
+
loss = self._add_mask_regularization(loss, mask)
|
495
|
+
|
496
|
+
return loss
|
497
|
+
|
498
|
+
def _apply_homo_regularization(self, loss: Tensor,
|
499
|
+
edge_mask: Tensor) -> Tensor:
|
500
|
+
"""Apply regularization for homogeneous graph."""
|
501
|
+
return self._add_mask_regularization(loss, edge_mask)
|
502
|
+
|
503
|
+
def _add_mask_regularization(self, loss: Tensor, mask: Tensor) -> Tensor:
|
504
|
+
"""Add size and entropy regularization for a mask."""
|
505
|
+
# Apply sigmoid for mask values
|
506
|
+
mask = mask.sigmoid()
|
507
|
+
|
508
|
+
# Size regularization
|
253
509
|
size_loss = mask.sum() * self.coeffs['edge_size']
|
254
|
-
|
255
|
-
|
510
|
+
|
511
|
+
# Entropy regularization
|
512
|
+
masked = 0.99 * mask + 0.005
|
513
|
+
mask_ent = -masked * masked.log() - (1 - masked) * (1 - masked).log()
|
256
514
|
mask_ent_loss = mask_ent.mean() * self.coeffs['edge_ent']
|
257
515
|
|
258
516
|
return loss + size_loss + mask_ent_loss
|
@@ -53,7 +53,7 @@ from ._negative_sampling import (negative_sampling, batched_negative_sampling,
|
|
53
53
|
structured_negative_sampling_feasible)
|
54
54
|
from .augmentation import shuffle_node, mask_feature, add_random_edge
|
55
55
|
from ._tree_decomposition import tree_decomposition
|
56
|
-
from .embedding import get_embeddings
|
56
|
+
from .embedding import get_embeddings, get_embeddings_hetero
|
57
57
|
from ._trim_to_layer import trim_to_layer
|
58
58
|
from .ppr import get_ppr
|
59
59
|
from ._train_test_split_edges import train_test_split_edges
|
@@ -145,6 +145,7 @@ __all__ = [
|
|
145
145
|
'add_random_edge',
|
146
146
|
'tree_decomposition',
|
147
147
|
'get_embeddings',
|
148
|
+
'get_embeddings_hetero',
|
148
149
|
'trim_to_layer',
|
149
150
|
'get_ppr',
|
150
151
|
'train_test_split_edges',
|
@@ -1,9 +1,11 @@
|
|
1
1
|
import warnings
|
2
|
-
from typing import Any, List
|
2
|
+
from typing import Any, Dict, List, Optional, Type
|
3
3
|
|
4
4
|
import torch
|
5
5
|
from torch import Tensor
|
6
6
|
|
7
|
+
from torch_geometric.typing import NodeType
|
8
|
+
|
7
9
|
|
8
10
|
def get_embeddings(
|
9
11
|
model: torch.nn.Module,
|
@@ -52,3 +54,88 @@ def get_embeddings(
|
|
52
54
|
handle.remove()
|
53
55
|
|
54
56
|
return embeddings
|
57
|
+
|
58
|
+
|
59
|
+
def get_embeddings_hetero(
|
60
|
+
model: torch.nn.Module,
|
61
|
+
supported_models: Optional[List[Type[torch.nn.Module]]] = None,
|
62
|
+
*args: Any,
|
63
|
+
**kwargs: Any,
|
64
|
+
) -> Dict[NodeType, List[Tensor]]:
|
65
|
+
"""Returns the output embeddings of all
|
66
|
+
:class:`~torch_geometric.nn.conv.MessagePassing` layers in a heterogeneous
|
67
|
+
:obj:`model`, organized by node type.
|
68
|
+
|
69
|
+
Internally, this method registers forward hooks on all modules that process
|
70
|
+
heterogeneous graphs in the model and runs the forward pass of the model.
|
71
|
+
For heterogeneous models, the output is a dictionary where each key is a
|
72
|
+
node type and each value is a list of embeddings from different layers.
|
73
|
+
|
74
|
+
Args:
|
75
|
+
model (torch.nn.Module): The heterogeneous GNN model.
|
76
|
+
supported_models (List[Type[torch.nn.Module]], optional): A list of
|
77
|
+
supported model classes. If not provided, defaults to
|
78
|
+
[HGTConv, HANConv, HeteroConv].
|
79
|
+
*args: Arguments passed to the model.
|
80
|
+
**kwargs (optional): Additional keyword arguments passed to the model.
|
81
|
+
|
82
|
+
Returns:
|
83
|
+
Dict[NodeType, List[Tensor]]: A dictionary mapping each node type to
|
84
|
+
a list of embeddings from different layers.
|
85
|
+
"""
|
86
|
+
from torch_geometric.nn import HANConv, HeteroConv, HGTConv
|
87
|
+
if not supported_models:
|
88
|
+
supported_models = [HGTConv, HANConv, HeteroConv]
|
89
|
+
|
90
|
+
# Dictionary to store node embeddings by type
|
91
|
+
node_embeddings_dict: Dict[NodeType, List[Tensor]] = {}
|
92
|
+
|
93
|
+
# Hook function to capture node embeddings
|
94
|
+
def hook(model: torch.nn.Module, inputs: Any, outputs: Any) -> None:
|
95
|
+
# Check if the outputs is a dictionary mapping node types to embeddings
|
96
|
+
if isinstance(outputs, dict) and outputs:
|
97
|
+
# Store embeddings for each node type
|
98
|
+
for node_type, embedding in outputs.items():
|
99
|
+
# Made sure that the outputs are a dictionary mapping node
|
100
|
+
# types to embeddings and remove the false positives.
|
101
|
+
if node_type not in node_embeddings_dict:
|
102
|
+
node_embeddings_dict[node_type] = []
|
103
|
+
node_embeddings_dict[node_type].append(embedding.clone())
|
104
|
+
|
105
|
+
# List to store hook handles
|
106
|
+
hook_handles = []
|
107
|
+
|
108
|
+
# Find ModuleDict objects in the model
|
109
|
+
for _, module in model.named_modules():
|
110
|
+
# Handle the native heterogeneous models, e.g. HGTConv, HANConv
|
111
|
+
# and HeteroConv, etc.
|
112
|
+
if isinstance(module, tuple(supported_models)):
|
113
|
+
hook_handles.append(module.register_forward_hook(hook))
|
114
|
+
else:
|
115
|
+
# Handle the heterogeneous models that are generated by calling
|
116
|
+
# to_hetero() on the homogeneous models.
|
117
|
+
submodules = list(module.children())
|
118
|
+
submodules_contains_module_dict = any([
|
119
|
+
isinstance(submodule, torch.nn.ModuleDict)
|
120
|
+
for submodule in submodules
|
121
|
+
])
|
122
|
+
if submodules_contains_module_dict:
|
123
|
+
hook_handles.append(module.register_forward_hook(hook))
|
124
|
+
|
125
|
+
if len(hook_handles) == 0:
|
126
|
+
warnings.warn("The 'model' does not have any heterogenous "
|
127
|
+
"'MessagePassing' layers")
|
128
|
+
|
129
|
+
# Run the model forward pass
|
130
|
+
training = model.training
|
131
|
+
model.eval()
|
132
|
+
|
133
|
+
with torch.no_grad():
|
134
|
+
model(*args, **kwargs)
|
135
|
+
model.train(training)
|
136
|
+
|
137
|
+
# Clean up hooks
|
138
|
+
for handle in hook_handles:
|
139
|
+
handle.remove()
|
140
|
+
|
141
|
+
return node_embeddings_dict
|
File without changes
|
{pyg_nightly-2.7.0.dev20250416.dist-info → pyg_nightly-2.7.0.dev20250417.dist-info}/licenses/LICENSE
RENAMED
File without changes
|