pyg-nightly 2.7.0.dev20250423__py3-none-any.whl → 2.7.0.dev20250424__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pyg-nightly
-Version: 2.7.0.dev20250423
+Version: 2.7.0.dev20250424
 Summary: Graph Neural Network Library for PyTorch
 Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
 Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
-torch_geometric/__init__.py,sha256=BRFGlcJ1nDb4azLG0EB7y7xon3yTqDiN8YOQ9ACnscI,1978
+torch_geometric/__init__.py,sha256=ERKw5z1mre0dUq9Ds7it-e7OfZSTr-CjZED7ofIMMW4,1978
 torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
 torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
 torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -195,7 +195,7 @@ torch_geometric/explain/config.py,sha256=_0j67NAwPwjrWHPncNywCT-oKyMiryJNxufxVN1
 torch_geometric/explain/explainer.py,sha256=8_NZTmlT4WO9RgKxpSUQRt3rbVwFURF5bSWOPlfOLjA,10667
 torch_geometric/explain/explanation.py,sha256=Z2NlgavEnq0QadEr6p6pxAhV6lU7WrlcJLFWbTdsmvg,14903
 torch_geometric/explain/algorithm/__init__.py,sha256=fE29xbd0bPxg-EfrB2BDmmY9QnyO-7TgvYduGHofm5o,496
-torch_geometric/explain/algorithm/attention_explainer.py,sha256=iRWgrUVoAn42DpVPE0jZclLB6OtUOArKl5dn53WmCc4,4545
+torch_geometric/explain/algorithm/attention_explainer.py,sha256=65iGLmOt00ERtBDVxAoydIchykdWZU24aXzSzUGzQEI,11304
 torch_geometric/explain/algorithm/base.py,sha256=wwJcREUFKDLFUDjRa9o4X3DWqQgMvhS3Iciwb6Evtjc,6922
 torch_geometric/explain/algorithm/captum.py,sha256=k6hNgC5Kn9lVirOYVJzej8-hRuf5C2mPFUXFLd2wWsY,12857
 torch_geometric/explain/algorithm/captum_explainer.py,sha256=oz-c40hvdzii4_chEQPHzQo_dFjHr9HLuJhDLsqRIVU,7346
@@ -636,7 +636,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
 torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
 torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
 torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
-pyg_nightly-2.7.0.dev20250423.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
-pyg_nightly-2.7.0.dev20250423.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
-pyg_nightly-2.7.0.dev20250423.dist-info/METADATA,sha256=OOH3EHfpXcmpGeyXvgYej5Brib1sYdXOJDrJkPwGaXs,62979
-pyg_nightly-2.7.0.dev20250423.dist-info/RECORD,,
+pyg_nightly-2.7.0.dev20250424.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
+pyg_nightly-2.7.0.dev20250424.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+pyg_nightly-2.7.0.dev20250424.dist-info/METADATA,sha256=fQqhKZ4bkgt-_tyrSNhIMz0ZUixa4ctPYvIPGdWCQzo,62979
+pyg_nightly-2.7.0.dev20250424.dist-info/RECORD,,
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
 
-__version__ = '2.7.0.dev20250423'
+__version__ = '2.7.0.dev20250424'
 
 __all__ = [
     'Index',
@@ -1,13 +1,14 @@
 import logging
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union, overload
 
 import torch
 from torch import Tensor
 
-from torch_geometric.explain import Explanation
+from torch_geometric.explain import Explanation, HeteroExplanation
 from torch_geometric.explain.algorithm import ExplainerAlgorithm
 from torch_geometric.explain.config import ExplanationType, ModelTaskLevel
 from torch_geometric.nn.conv.message_passing import MessagePassing
+from torch_geometric.typing import EdgeType, NodeType
 
 
 class AttentionExplainer(ExplainerAlgorithm):
@@ -26,7 +27,9 @@ class AttentionExplainer(ExplainerAlgorithm):
     def __init__(self, reduce: str = 'max'):
         super().__init__()
         self.reduce = reduce
+        self.is_hetero = False
 
+    @overload
     def forward(
         self,
         model: torch.nn.Module,
@@ -37,65 +40,252 @@ class AttentionExplainer(ExplainerAlgorithm):
         index: Optional[Union[int, Tensor]] = None,
         **kwargs,
     ) -> Explanation:
-        if isinstance(x, dict) or isinstance(edge_index, dict):
-            raise ValueError(f"Heterogeneous graphs not yet supported in "
-                             f"'{self.__class__.__name__}'")
+        ...
 
-        hard_edge_mask = None
-        if self.model_config.task_level == ModelTaskLevel.node:
-            # We need to compute the hard edge mask to properly clean up edge
-            # attributions not involved during message passing:
-            _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,
-                                                     num_nodes=x.size(0))
+    @overload
+    def forward(
+        self,
+        model: torch.nn.Module,
+        x: Dict[NodeType, Tensor],
+        edge_index: Dict[EdgeType, Tensor],
+        *,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ) -> HeteroExplanation:
+        ...
+
+    def forward(
+        self,
+        model: torch.nn.Module,
+        x: Union[Tensor, Dict[NodeType, Tensor]],
+        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+        *,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ) -> Union[Explanation, HeteroExplanation]:
+        """Generate explanations based on attention coefficients."""
+        self.is_hetero = isinstance(x, dict)
+
+        # Collect attention coefficients
+        alphas_dict = self._collect_attention_coefficients(
+            model, x, edge_index, **kwargs)
+
+        # Process attention coefficients
+        if self.is_hetero:
+            return self._create_hetero_explanation(model, alphas_dict,
+                                                   edge_index, index, x)
+        else:
+            return self._create_homo_explanation(model, alphas_dict,
+                                                 edge_index, index, x)
+
+    @overload
+    def _collect_attention_coefficients(
+        self,
+        model: torch.nn.Module,
+        x: Tensor,
+        edge_index: Tensor,
+        **kwargs,
+    ) -> List[Tensor]:
+        ...
+
+    @overload
+    def _collect_attention_coefficients(
+        self,
+        model: torch.nn.Module,
+        x: Dict[NodeType, Tensor],
+        edge_index: Dict[EdgeType, Tensor],
+        **kwargs,
+    ) -> Dict[EdgeType, List[Tensor]]:
+        ...
+
+    def _collect_attention_coefficients(
+        self,
+        model: torch.nn.Module,
+        x: Union[Tensor, Dict[NodeType, Tensor]],
+        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+        **kwargs,
+    ) -> Union[List[Tensor], Dict[EdgeType, List[Tensor]]]:
+        """Collect attention coefficients from model layers."""
+        if self.is_hetero:
+            # For heterogeneous graphs, store alphas by edge type
+            alphas_dict: Dict[EdgeType, List[Tensor]] = {}
+
+            # Get list of edge types
+            edge_types = list(edge_index.keys())
+
+            # Hook function to capture attention coefficients by edge type
+            def hook(module, msg_kwargs, out):
+                # Find edge type from the module's full name
+                module_name = getattr(module, '_name', None)
+                if module_name is None:
+                    return
 
-        alphas: List[Tensor] = []
+                edge_type = None
+                for edge_tuple in edge_types:
+                    src_type, edge_name, dst_type = edge_tuple
+                    # Check if all components appear in the module name in
+                    # order
+                    try:
+                        src_idx = module_name.index(src_type)
+                        edge_idx = module_name.index(edge_name, src_idx)
+                        dst_idx = module_name.index(dst_type, edge_idx)
+                        if src_idx < edge_idx < dst_idx:
+                            edge_type = edge_tuple
+                            break
+                    except ValueError:  # Component not found
+                        continue
+
+                if edge_type is None:
+                    return
+
+                if edge_type not in alphas_dict:
+                    alphas_dict[edge_type] = []
+
+                # Extract alpha from message kwargs or module
+                if 'alpha' in msg_kwargs[0]:
+                    alphas_dict[edge_type].append(
+                        msg_kwargs[0]['alpha'].detach())
+                elif getattr(module, '_alpha', None) is not None:
+                    alphas_dict[edge_type].append(module._alpha.detach())
+        else:
+            # For homogeneous graphs, store all alphas in a list
+            alphas: List[Tensor] = []
 
-        def hook(module, msg_kwargs, out):
-            if 'alpha' in msg_kwargs[0]:
-                alphas.append(msg_kwargs[0]['alpha'].detach())
-            elif getattr(module, '_alpha', None) is not None:
-                alphas.append(module._alpha.detach())
+            def hook(module, msg_kwargs, out):
+                if 'alpha' in msg_kwargs[0]:
+                    alphas.append(msg_kwargs[0]['alpha'].detach())
+                elif getattr(module, '_alpha', None) is not None:
+                    alphas.append(module._alpha.detach())
 
+        # Register hooks for all message passing modules
         hook_handles = []
-        for module in model.modules():  # Register message forward hooks:
-            if (isinstance(module, MessagePassing)
-                    and module.explain is not False):
+        for name, module in model.named_modules():
+            if isinstance(module,
+                          MessagePassing) and module.explain is not False:
+                # Store name for hetero graph lookup in the hook
+                if self.is_hetero:
+                    module._name = name
+
                 hook_handles.append(module.register_message_forward_hook(hook))
 
+        # Forward pass to collect attention coefficients.
         model(x, edge_index, **kwargs)
 
-        for handle in hook_handles:  # Remove hooks:
+        # Remove hooks
+        for handle in hook_handles:
             handle.remove()
 
-        if len(alphas) == 0:
-            raise ValueError("Could not collect any attention coefficients. "
-                             "Please ensure that your model is using "
-                             "attention-based GNN layers.")
+        # Check if we collected any attention coefficients.
+        if self.is_hetero:
+            if not alphas_dict:
+                raise ValueError(
+                    "Could not collect any attention coefficients. "
+                    "Please ensure that your model is using "
+                    "attention-based GNN layers.")
+            return alphas_dict
+        else:
+            if not alphas:
+                raise ValueError(
+                    "Could not collect any attention coefficients. "
+                    "Please ensure that your model is using "
+                    "attention-based GNN layers.")
+            return alphas
 
+    def _process_attention_coefficients(
+        self,
+        alphas: List[Tensor],
+        edge_index_size: int,
+    ) -> Tensor:
+        """Process collected attention coefficients into a single mask."""
         for i, alpha in enumerate(alphas):
-            alpha = alpha[:edge_index.size(1)]  # Respect potential self-loops.
+            # Ensure alpha doesn't exceed edge_index size
+            alpha = alpha[:edge_index_size]
+
+            # Reduce multi-head attention
             if alpha.dim() == 2:
                 alpha = getattr(torch, self.reduce)(alpha, dim=-1)
-                if isinstance(alpha, tuple):  # Respect `torch.max`:
+                if isinstance(alpha, tuple):  # Handle torch.max output
                     alpha = alpha[0]
             elif alpha.dim() > 2:
-                raise ValueError(f"Can not reduce attention coefficients of "
+                raise ValueError(f"Cannot reduce attention coefficients of "
                                  f"shape {list(alpha.size())}")
             alphas[i] = alpha
 
+        # Combine attention coefficients across layers
         if len(alphas) > 1:
             alpha = torch.stack(alphas, dim=-1)
             alpha = getattr(torch, self.reduce)(alpha, dim=-1)
-            if isinstance(alpha, tuple):  # Respect `torch.max`:
+            if isinstance(alpha, tuple):  # Handle torch.max output
                 alpha = alpha[0]
         else:
             alpha = alphas[0]
 
+        return alpha
+
+    def _create_homo_explanation(
+        self,
+        model: torch.nn.Module,
+        alphas: List[Tensor],
+        edge_index: Tensor,
+        index: Optional[Union[int, Tensor]],
+        x: Tensor,
+    ) -> Explanation:
+        """Create explanation for homogeneous graph."""
+        # Get hard edge mask for node-level tasks
+        hard_edge_mask = None
+        if self.model_config.task_level == ModelTaskLevel.node:
+            _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,
+                                                     num_nodes=x.size(0))
+
+        # Process attention coefficients
+        alpha = self._process_attention_coefficients(alphas,
+                                                     edge_index.size(1))
+
+        # Post-process mask with hard edge mask if needed
         alpha = self._post_process_mask(alpha, hard_edge_mask,
                                         apply_sigmoid=False)
 
         return Explanation(edge_mask=alpha)
 
+    def _create_hetero_explanation(
+        self,
+        model: torch.nn.Module,
+        alphas_dict: Dict[EdgeType, List[Tensor]],
+        edge_index: Dict[EdgeType, Tensor],
+        index: Optional[Union[int, Tensor]],
+        x: Dict[NodeType, Tensor],
+    ) -> HeteroExplanation:
+        """Create explanation for heterogeneous graph."""
+        edge_masks_dict = {}
+
+        # Process each edge type separately
+        for edge_type, alphas in alphas_dict.items():
+            if not alphas:
+                continue
+
+            # Get hard edge mask for node-level tasks
+            hard_edge_mask = None
+            if self.model_config.task_level == ModelTaskLevel.node:
+                src_type, _, dst_type = edge_type
+                _, hard_edge_mask = self._get_hard_masks(
+                    model, index, edge_index[edge_type],
+                    num_nodes=max(x[src_type].size(0), x[dst_type].size(0)))
+
+            # Process attention coefficients for this edge type
+            alpha = self._process_attention_coefficients(
+                alphas, edge_index[edge_type].size(1))
+
+            # Apply hard mask if available
+            edge_masks_dict[edge_type] = self._post_process_mask(
+                alpha, hard_edge_mask, apply_sigmoid=False)
+
+        # Create heterogeneous explanation
+        explanation = HeteroExplanation()
+        explanation.set_value_dict('edge_mask', edge_masks_dict)
+        return explanation
+
     def supports(self) -> bool:
         explanation_type = self.explainer_config.explanation_type
         if explanation_type != ExplanationType.model:
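
Taken together, the changes to attention_explainer.py mean that passing dictionaries of node features and edge indices now returns a HeteroExplanation with one edge mask per edge type, where the previous nightly raised a ValueError. A minimal usage sketch through the public Explainer API follows; hetero_gat_model and data are hypothetical placeholders, and the sketch assumes an attention-based heterogeneous model (for example, a GAT encoder converted with to_hetero) so that attention coefficients are available to collect:

from torch_geometric.explain import Explainer
from torch_geometric.explain.algorithm import AttentionExplainer

# `hetero_gat_model` and `data` are placeholders: any heterogeneous model
# whose message passing layers expose attention coefficients should work.
explainer = Explainer(
    model=hetero_gat_model,
    algorithm=AttentionExplainer(reduce='max'),
    explanation_type='model',  # `supports()` only allows model explanations.
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',
    ),
)

# Dict inputs now take the heterogeneous code path added above:
explanation = explainer(data.x_dict, data.edge_index_dict, index=0)
print(explanation.collect('edge_mask'))  # One mask per edge type.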