pyg-nightly 2.7.0.dev20250423__py3-none-any.whl → 2.7.0.dev20250425__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250423.dist-info → pyg_nightly-2.7.0.dev20250425.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250423.dist-info → pyg_nightly-2.7.0.dev20250425.dist-info}/RECORD +6 -6
- torch_geometric/__init__.py +1 -1
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- {pyg_nightly-2.7.0.dev20250423.dist-info → pyg_nightly-2.7.0.dev20250425.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250423.dist-info → pyg_nightly-2.7.0.dev20250425.dist-info}/licenses/LICENSE +0 -0
{pyg_nightly-2.7.0.dev20250423.dist-info → pyg_nightly-2.7.0.dev20250425.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: pyg-nightly
|
3
|
-
Version: 2.7.0.dev20250423
|
3
|
+
Version: 2.7.0.dev20250425
|
4
4
|
Summary: Graph Neural Network Library for PyTorch
|
5
5
|
Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
|
6
6
|
Author-email: Matthias Fey <matthias@pyg.org>
|
@@ -1,4 +1,4 @@
|
|
1
|
-
torch_geometric/__init__.py,sha256=
|
1
|
+
torch_geometric/__init__.py,sha256=C-jEor7eBeuebCtYnau6wgH8uwUZMc0YCHWJ8KjRZ2s,1978
|
2
2
|
torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
|
3
3
|
torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
|
4
4
|
torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
|
@@ -195,7 +195,7 @@ torch_geometric/explain/config.py,sha256=_0j67NAwPwjrWHPncNywCT-oKyMiryJNxufxVN1
|
|
195
195
|
torch_geometric/explain/explainer.py,sha256=8_NZTmlT4WO9RgKxpSUQRt3rbVwFURF5bSWOPlfOLjA,10667
|
196
196
|
torch_geometric/explain/explanation.py,sha256=Z2NlgavEnq0QadEr6p6pxAhV6lU7WrlcJLFWbTdsmvg,14903
|
197
197
|
torch_geometric/explain/algorithm/__init__.py,sha256=fE29xbd0bPxg-EfrB2BDmmY9QnyO-7TgvYduGHofm5o,496
|
198
|
-
torch_geometric/explain/algorithm/attention_explainer.py,sha256=
|
198
|
+
torch_geometric/explain/algorithm/attention_explainer.py,sha256=65iGLmOt00ERtBDVxAoydIchykdWZU24aXzSzUGzQEI,11304
|
199
199
|
torch_geometric/explain/algorithm/base.py,sha256=wwJcREUFKDLFUDjRa9o4X3DWqQgMvhS3Iciwb6Evtjc,6922
|
200
200
|
torch_geometric/explain/algorithm/captum.py,sha256=k6hNgC5Kn9lVirOYVJzej8-hRuf5C2mPFUXFLd2wWsY,12857
|
201
201
|
torch_geometric/explain/algorithm/captum_explainer.py,sha256=oz-c40hvdzii4_chEQPHzQo_dFjHr9HLuJhDLsqRIVU,7346
|
@@ -636,7 +636,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
|
|
636
636
|
torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
|
637
637
|
torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
|
638
638
|
torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
|
639
|
-
pyg_nightly-2.7.0.
|
640
|
-
pyg_nightly-2.7.0.
|
641
|
-
pyg_nightly-2.7.0.
|
642
|
-
pyg_nightly-2.7.0.
|
639
|
+
pyg_nightly-2.7.0.dev20250425.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
|
640
|
+
pyg_nightly-2.7.0.dev20250425.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
|
641
|
+
pyg_nightly-2.7.0.dev20250425.dist-info/METADATA,sha256=bSC09SKLOFJ61t6NRu4O8DvhYvowg2iDcYUoeDPD0U0,62979
|
642
|
+
pyg_nightly-2.7.0.dev20250425.dist-info/RECORD,,
|
torch_geometric/__init__.py
CHANGED
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
|
|
31
31
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
32
32
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
33
33
|
|
34
|
-
__version__ = '2.7.0.
|
34
|
+
__version__ = '2.7.0.dev20250425'
|
35
35
|
|
36
36
|
__all__ = [
|
37
37
|
'Index',
|
@@ -1,13 +1,14 @@
|
|
1
1
|
import logging
|
2
|
-
from typing import List, Optional, Union
|
2
|
+
from typing import Dict, List, Optional, Union, overload
|
3
3
|
|
4
4
|
import torch
|
5
5
|
from torch import Tensor
|
6
6
|
|
7
|
-
from torch_geometric.explain import Explanation
|
7
|
+
from torch_geometric.explain import Explanation, HeteroExplanation
|
8
8
|
from torch_geometric.explain.algorithm import ExplainerAlgorithm
|
9
9
|
from torch_geometric.explain.config import ExplanationType, ModelTaskLevel
|
10
10
|
from torch_geometric.nn.conv.message_passing import MessagePassing
|
11
|
+
from torch_geometric.typing import EdgeType, NodeType
|
11
12
|
|
12
13
|
|
13
14
|
class AttentionExplainer(ExplainerAlgorithm):
|
@@ -26,7 +27,9 @@ class AttentionExplainer(ExplainerAlgorithm):
|
|
26
27
|
def __init__(self, reduce: str = 'max'):
|
27
28
|
super().__init__()
|
28
29
|
self.reduce = reduce
|
30
|
+
self.is_hetero = False
|
29
31
|
|
32
|
+
@overload
|
30
33
|
def forward(
|
31
34
|
self,
|
32
35
|
model: torch.nn.Module,
|
@@ -37,65 +40,252 @@ class AttentionExplainer(ExplainerAlgorithm):
|
|
37
40
|
index: Optional[Union[int, Tensor]] = None,
|
38
41
|
**kwargs,
|
39
42
|
) -> Explanation:
|
40
|
-
|
41
|
-
raise ValueError(f"Heterogeneous graphs not yet supported in "
|
42
|
-
f"'{self.__class__.__name__}'")
|
43
|
+
...
|
43
44
|
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
45
|
+
@overload
|
46
|
+
def forward(
|
47
|
+
self,
|
48
|
+
model: torch.nn.Module,
|
49
|
+
x: Dict[NodeType, Tensor],
|
50
|
+
edge_index: Dict[EdgeType, Tensor],
|
51
|
+
*,
|
52
|
+
target: Tensor,
|
53
|
+
index: Optional[Union[int, Tensor]] = None,
|
54
|
+
**kwargs,
|
55
|
+
) -> HeteroExplanation:
|
56
|
+
...
|
57
|
+
|
58
|
+
def forward(
|
59
|
+
self,
|
60
|
+
model: torch.nn.Module,
|
61
|
+
x: Union[Tensor, Dict[NodeType, Tensor]],
|
62
|
+
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
|
63
|
+
*,
|
64
|
+
target: Tensor,
|
65
|
+
index: Optional[Union[int, Tensor]] = None,
|
66
|
+
**kwargs,
|
67
|
+
) -> Union[Explanation, HeteroExplanation]:
|
68
|
+
"""Generate explanations based on attention coefficients."""
|
69
|
+
self.is_hetero = isinstance(x, dict)
|
70
|
+
|
71
|
+
# Collect attention coefficients
|
72
|
+
alphas_dict = self._collect_attention_coefficients(
|
73
|
+
model, x, edge_index, **kwargs)
|
74
|
+
|
75
|
+
# Process attention coefficients
|
76
|
+
if self.is_hetero:
|
77
|
+
return self._create_hetero_explanation(model, alphas_dict,
|
78
|
+
edge_index, index, x)
|
79
|
+
else:
|
80
|
+
return self._create_homo_explanation(model, alphas_dict,
|
81
|
+
edge_index, index, x)
|
82
|
+
|
83
|
+
@overload
|
84
|
+
def _collect_attention_coefficients(
|
85
|
+
self,
|
86
|
+
model: torch.nn.Module,
|
87
|
+
x: Tensor,
|
88
|
+
edge_index: Tensor,
|
89
|
+
**kwargs,
|
90
|
+
) -> List[Tensor]:
|
91
|
+
...
|
92
|
+
|
93
|
+
@overload
|
94
|
+
def _collect_attention_coefficients(
|
95
|
+
self,
|
96
|
+
model: torch.nn.Module,
|
97
|
+
x: Dict[NodeType, Tensor],
|
98
|
+
edge_index: Dict[EdgeType, Tensor],
|
99
|
+
**kwargs,
|
100
|
+
) -> Dict[EdgeType, List[Tensor]]:
|
101
|
+
...
|
102
|
+
|
103
|
+
def _collect_attention_coefficients(
|
104
|
+
self,
|
105
|
+
model: torch.nn.Module,
|
106
|
+
x: Union[Tensor, Dict[NodeType, Tensor]],
|
107
|
+
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
|
108
|
+
**kwargs,
|
109
|
+
) -> Union[List[Tensor], Dict[EdgeType, List[Tensor]]]:
|
110
|
+
"""Collect attention coefficients from model layers."""
|
111
|
+
if self.is_hetero:
|
112
|
+
# For heterogeneous graphs, store alphas by edge type
|
113
|
+
alphas_dict: Dict[EdgeType, List[Tensor]] = {}
|
114
|
+
|
115
|
+
# Get list of edge types
|
116
|
+
edge_types = list(edge_index.keys())
|
117
|
+
|
118
|
+
# Hook function to capture attention coefficients by edge type
|
119
|
+
def hook(module, msg_kwargs, out):
|
120
|
+
# Find edge type from the module's full name
|
121
|
+
module_name = getattr(module, '_name', None)
|
122
|
+
if module_name is None:
|
123
|
+
return
|
50
124
|
|
51
|
-
|
125
|
+
edge_type = None
|
126
|
+
for edge_tuple in edge_types:
|
127
|
+
src_type, edge_name, dst_type = edge_tuple
|
128
|
+
# Check if all components appear in the module name in
|
129
|
+
# order
|
130
|
+
try:
|
131
|
+
src_idx = module_name.index(src_type)
|
132
|
+
edge_idx = module_name.index(edge_name, src_idx)
|
133
|
+
dst_idx = module_name.index(dst_type, edge_idx)
|
134
|
+
if src_idx < edge_idx < dst_idx:
|
135
|
+
edge_type = edge_tuple
|
136
|
+
break
|
137
|
+
except ValueError: # Component not found
|
138
|
+
continue
|
139
|
+
|
140
|
+
if edge_type is None:
|
141
|
+
return
|
142
|
+
|
143
|
+
if edge_type not in alphas_dict:
|
144
|
+
alphas_dict[edge_type] = []
|
145
|
+
|
146
|
+
# Extract alpha from message kwargs or module
|
147
|
+
if 'alpha' in msg_kwargs[0]:
|
148
|
+
alphas_dict[edge_type].append(
|
149
|
+
msg_kwargs[0]['alpha'].detach())
|
150
|
+
elif getattr(module, '_alpha', None) is not None:
|
151
|
+
alphas_dict[edge_type].append(module._alpha.detach())
|
152
|
+
else:
|
153
|
+
# For homogeneous graphs, store all alphas in a list
|
154
|
+
alphas: List[Tensor] = []
|
52
155
|
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
156
|
+
def hook(module, msg_kwargs, out):
|
157
|
+
if 'alpha' in msg_kwargs[0]:
|
158
|
+
alphas.append(msg_kwargs[0]['alpha'].detach())
|
159
|
+
elif getattr(module, '_alpha', None) is not None:
|
160
|
+
alphas.append(module._alpha.detach())
|
58
161
|
|
162
|
+
# Register hooks for all message passing modules
|
59
163
|
hook_handles = []
|
60
|
-
for module in model.
|
61
|
-
if
|
62
|
-
|
164
|
+
for name, module in model.named_modules():
|
165
|
+
if isinstance(module,
|
166
|
+
MessagePassing) and module.explain is not False:
|
167
|
+
# Store name for hetero graph lookup in the hook
|
168
|
+
if self.is_hetero:
|
169
|
+
module._name = name
|
170
|
+
|
63
171
|
hook_handles.append(module.register_message_forward_hook(hook))
|
64
172
|
|
173
|
+
# Forward pass to collect attention coefficients.
|
65
174
|
model(x, edge_index, **kwargs)
|
66
175
|
|
67
|
-
|
176
|
+
# Remove hooks
|
177
|
+
for handle in hook_handles:
|
68
178
|
handle.remove()
|
69
179
|
|
70
|
-
if
|
71
|
-
|
72
|
-
|
73
|
-
|
180
|
+
# Check if we collected any attention coefficients.
|
181
|
+
if self.is_hetero:
|
182
|
+
if not alphas_dict:
|
183
|
+
raise ValueError(
|
184
|
+
"Could not collect any attention coefficients. "
|
185
|
+
"Please ensure that your model is using "
|
186
|
+
"attention-based GNN layers.")
|
187
|
+
return alphas_dict
|
188
|
+
else:
|
189
|
+
if not alphas:
|
190
|
+
raise ValueError(
|
191
|
+
"Could not collect any attention coefficients. "
|
192
|
+
"Please ensure that your model is using "
|
193
|
+
"attention-based GNN layers.")
|
194
|
+
return alphas
|
74
195
|
|
196
|
+
def _process_attention_coefficients(
|
197
|
+
self,
|
198
|
+
alphas: List[Tensor],
|
199
|
+
edge_index_size: int,
|
200
|
+
) -> Tensor:
|
201
|
+
"""Process collected attention coefficients into a single mask."""
|
75
202
|
for i, alpha in enumerate(alphas):
|
76
|
-
|
203
|
+
# Ensure alpha doesn't exceed edge_index size
|
204
|
+
alpha = alpha[:edge_index_size]
|
205
|
+
|
206
|
+
# Reduce multi-head attention
|
77
207
|
if alpha.dim() == 2:
|
78
208
|
alpha = getattr(torch, self.reduce)(alpha, dim=-1)
|
79
|
-
if isinstance(alpha, tuple): #
|
209
|
+
if isinstance(alpha, tuple): # Handle torch.max output
|
80
210
|
alpha = alpha[0]
|
81
211
|
elif alpha.dim() > 2:
|
82
|
-
raise ValueError(f"
|
212
|
+
raise ValueError(f"Cannot reduce attention coefficients of "
|
83
213
|
f"shape {list(alpha.size())}")
|
84
214
|
alphas[i] = alpha
|
85
215
|
|
216
|
+
# Combine attention coefficients across layers
|
86
217
|
if len(alphas) > 1:
|
87
218
|
alpha = torch.stack(alphas, dim=-1)
|
88
219
|
alpha = getattr(torch, self.reduce)(alpha, dim=-1)
|
89
|
-
if isinstance(alpha, tuple): #
|
220
|
+
if isinstance(alpha, tuple): # Handle torch.max output
|
90
221
|
alpha = alpha[0]
|
91
222
|
else:
|
92
223
|
alpha = alphas[0]
|
93
224
|
|
225
|
+
return alpha
|
226
|
+
|
227
|
+
def _create_homo_explanation(
|
228
|
+
self,
|
229
|
+
model: torch.nn.Module,
|
230
|
+
alphas: List[Tensor],
|
231
|
+
edge_index: Tensor,
|
232
|
+
index: Optional[Union[int, Tensor]],
|
233
|
+
x: Tensor,
|
234
|
+
) -> Explanation:
|
235
|
+
"""Create explanation for homogeneous graph."""
|
236
|
+
# Get hard edge mask for node-level tasks
|
237
|
+
hard_edge_mask = None
|
238
|
+
if self.model_config.task_level == ModelTaskLevel.node:
|
239
|
+
_, hard_edge_mask = self._get_hard_masks(model, index, edge_index,
|
240
|
+
num_nodes=x.size(0))
|
241
|
+
|
242
|
+
# Process attention coefficients
|
243
|
+
alpha = self._process_attention_coefficients(alphas,
|
244
|
+
edge_index.size(1))
|
245
|
+
|
246
|
+
# Post-process mask with hard edge mask if needed
|
94
247
|
alpha = self._post_process_mask(alpha, hard_edge_mask,
|
95
248
|
apply_sigmoid=False)
|
96
249
|
|
97
250
|
return Explanation(edge_mask=alpha)
|
98
251
|
|
252
|
+
def _create_hetero_explanation(
|
253
|
+
self,
|
254
|
+
model: torch.nn.Module,
|
255
|
+
alphas_dict: Dict[EdgeType, List[Tensor]],
|
256
|
+
edge_index: Dict[EdgeType, Tensor],
|
257
|
+
index: Optional[Union[int, Tensor]],
|
258
|
+
x: Dict[NodeType, Tensor],
|
259
|
+
) -> HeteroExplanation:
|
260
|
+
"""Create explanation for heterogeneous graph."""
|
261
|
+
edge_masks_dict = {}
|
262
|
+
|
263
|
+
# Process each edge type separately
|
264
|
+
for edge_type, alphas in alphas_dict.items():
|
265
|
+
if not alphas:
|
266
|
+
continue
|
267
|
+
|
268
|
+
# Get hard edge mask for node-level tasks
|
269
|
+
hard_edge_mask = None
|
270
|
+
if self.model_config.task_level == ModelTaskLevel.node:
|
271
|
+
src_type, _, dst_type = edge_type
|
272
|
+
_, hard_edge_mask = self._get_hard_masks(
|
273
|
+
model, index, edge_index[edge_type],
|
274
|
+
num_nodes=max(x[src_type].size(0), x[dst_type].size(0)))
|
275
|
+
|
276
|
+
# Process attention coefficients for this edge type
|
277
|
+
alpha = self._process_attention_coefficients(
|
278
|
+
alphas, edge_index[edge_type].size(1))
|
279
|
+
|
280
|
+
# Apply hard mask if available
|
281
|
+
edge_masks_dict[edge_type] = self._post_process_mask(
|
282
|
+
alpha, hard_edge_mask, apply_sigmoid=False)
|
283
|
+
|
284
|
+
# Create heterogeneous explanation
|
285
|
+
explanation = HeteroExplanation()
|
286
|
+
explanation.set_value_dict('edge_mask', edge_masks_dict)
|
287
|
+
return explanation
|
288
|
+
|
99
289
|
def supports(self) -> bool:
|
100
290
|
explanation_type = self.explainer_config.explanation_type
|
101
291
|
if explanation_type != ExplanationType.model:
|
File without changes
|
{pyg_nightly-2.7.0.dev20250423.dist-info → pyg_nightly-2.7.0.dev20250425.dist-info}/licenses/LICENSE
RENAMED
File without changes
|