pyg-nightly 2.7.0.dev20250416__py3-none-any.whl → 2.7.0.dev20250418__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyg-nightly
3
- Version: 2.7.0.dev20250416
3
+ Version: 2.7.0.dev20250418
4
4
  Summary: Graph Neural Network Library for PyTorch
5
5
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
6
6
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
1
- torch_geometric/__init__.py,sha256=en3YAn8JWtZUWaMs1gx9X-Z0xZM9IfGB3fvTW53cJno,1978
1
+ torch_geometric/__init__.py,sha256=UCc69_Z_CAoLuIzpuXALW5om7y4PXXzvqORBng1FG70,1978
2
2
  torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
3
3
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
4
4
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -202,7 +202,7 @@ torch_geometric/explain/algorithm/captum_explainer.py,sha256=oz-c40hvdzii4_chEQP
202
202
  torch_geometric/explain/algorithm/dummy_explainer.py,sha256=jvcVQmfngmUWgoKa5p7CXzju2HM5D5DfieJhZW3gbLc,2872
203
203
  torch_geometric/explain/algorithm/gnn_explainer.py,sha256=iu45fGWdd4c6wNczWEAT-29HCAz7ncuoaS6cpx-xDJM,24660
204
204
  torch_geometric/explain/algorithm/graphmask_explainer.py,sha256=T2B081dK-JSpaQmutnkQd5xF3JF49_dPZCOgwqIKJDk,21367
205
- torch_geometric/explain/algorithm/pg_explainer.py,sha256=zPsl0tT9ISSWK1xP1KKpe1ZjUarhSB736WTtqwcmDIo,10372
205
+ torch_geometric/explain/algorithm/pg_explainer.py,sha256=LMlNcqSqtEP-IzYA7Xix6FoAogcrLUaEUAxDVyz2eyc,20162
206
206
  torch_geometric/explain/algorithm/utils.py,sha256=eh0ARPG41V7piVw5jdMYpV0p7WjTlpehnY-bWqPV_zg,2564
207
207
  torch_geometric/explain/metric/__init__.py,sha256=swLeuWVaM3K7UvowsH7q3BzfTq_W1vhcFY8nEP7vFPQ,301
208
208
  torch_geometric/explain/metric/basic.py,sha256=qN-cho4lxwPlw_X26svJrW5QOnw5GB3lLKf0Js_6rBE,1888
@@ -584,7 +584,7 @@ torch_geometric/transforms/to_superpixels.py,sha256=g8ysBv-ezcHn2gHucKuBtnbe-kBD
584
584
  torch_geometric/transforms/to_undirected.py,sha256=oklgrNzev7HjvVaBHwPQFo0RxcQpmcIebNbcv6vNCtY,2972
585
585
  torch_geometric/transforms/two_hop.py,sha256=XxZl3eztTjE00ZlyAIqYu36rjaRddQT-1v4AFF9VUBc,1313
586
586
  torch_geometric/transforms/virtual_node.py,sha256=FMGT6LZBH-SU2zmp76GKNqJBZ8PyS1_6Em2BbVhv8Tw,2932
587
- torch_geometric/utils/__init__.py,sha256=zSiljeQIG8aVXDL9Jowv6WJynfiSLt2w29XzUSu59CI,4930
587
+ torch_geometric/utils/__init__.py,sha256=aVet2bjRvr3URikJ6LpjLATz447YuBS6FuSu5l3JwLY,4982
588
588
  torch_geometric/utils/_assortativity.py,sha256=pe2Hv5xLWhTW7dgqVWNiwDgDVMxMbliTdLeQf5Y65Ug,2347
589
589
  torch_geometric/utils/_coalesce.py,sha256=m4s_maBhib0jByQi6Cd8dazzhFVshZXLfB9aykCZT2g,6769
590
590
  torch_geometric/utils/_degree.py,sha256=FcsGx5cQdrBmoCQ4qQ2csjsTiDICP1as4x1HD9y5XVk,1017
@@ -613,7 +613,7 @@ torch_geometric/utils/augmentation.py,sha256=1F0YCuaklZ9ZbXxdFV0oOoemWvLd8p60WvF
613
613
  torch_geometric/utils/convert.py,sha256=0KEJoBOzU-w-mMQu9QYaMhUqcrGBxBmeRl0hv8NPvII,21697
614
614
  torch_geometric/utils/cross_entropy.py,sha256=ZFS5bivtzv3EV9zqgKsekmuQyoZZggPSclhl_tRNHxo,3047
615
615
  torch_geometric/utils/dropout.py,sha256=gg0rDnD4FLvBaKSoLAkZwViAQflhLefJm6_Mju5dmQs,11416
616
- torch_geometric/utils/embedding.py,sha256=gcWcUv46W0bZBm9puUr7GpPrdzb-PWlD9bpwqBHnA-w,1675
616
+ torch_geometric/utils/embedding.py,sha256=b-CQ-aapEgahxSS7fuL4aNQX6GJROboV0xclZ_MmwO0,5179
617
617
  torch_geometric/utils/functions.py,sha256=orQdS_6EpzWSmBHSok3WhxCzLy9neB-cin1aTnlXY-8,703
618
618
  torch_geometric/utils/geodesic.py,sha256=-xsqE3FZU7Y9gMbucIlGJ4FM-3nk8o0AQBxIdN-QfEw,4770
619
619
  torch_geometric/utils/hetero.py,sha256=ok4uAAOyMiaeEPmvyS4DNoDwdKnLS2gmgs5WVVklxOo,5539
@@ -636,7 +636,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
636
636
  torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
637
637
  torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
638
638
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
639
- pyg_nightly-2.7.0.dev20250416.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
640
- pyg_nightly-2.7.0.dev20250416.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
641
- pyg_nightly-2.7.0.dev20250416.dist-info/METADATA,sha256=yDodf56EgttruZas0nqEgbfnFaNHm03BHQsFi-IkPf0,62979
642
- pyg_nightly-2.7.0.dev20250416.dist-info/RECORD,,
639
+ pyg_nightly-2.7.0.dev20250418.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
640
+ pyg_nightly-2.7.0.dev20250418.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
641
+ pyg_nightly-2.7.0.dev20250418.dist-info/METADATA,sha256=j8ciN8XlpllB1ulU5n8vM4eWsUnDsDiMaGZOKoyC0Rk,62979
642
+ pyg_nightly-2.7.0.dev20250418.dist-info/RECORD,,
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
31
31
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
32
32
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
33
33
 
34
- __version__ = '2.7.0.dev20250416'
34
+ __version__ = '2.7.0.dev20250418'
35
35
 
36
36
  __all__ = [
37
37
  'Index',
@@ -1,21 +1,26 @@
1
1
  import logging
2
- from typing import Optional, Union
2
+ from typing import Dict, Optional, Tuple, Union, overload
3
3
 
4
4
  import torch
5
5
  from torch import Tensor
6
6
  from torch.nn import ReLU, Sequential
7
7
 
8
- from torch_geometric.explain import Explanation
8
+ from torch_geometric.explain import Explanation, HeteroExplanation
9
9
  from torch_geometric.explain.algorithm import ExplainerAlgorithm
10
- from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
10
+ from torch_geometric.explain.algorithm.utils import (
11
+ clear_masks,
12
+ set_hetero_masks,
13
+ set_masks,
14
+ )
11
15
  from torch_geometric.explain.config import (
12
16
  ExplanationType,
13
17
  ModelMode,
14
18
  ModelTaskLevel,
15
19
  )
16
- from torch_geometric.nn import Linear
20
+ from torch_geometric.nn import HANConv, HeteroConv, HGTConv, Linear
17
21
  from torch_geometric.nn.inits import reset
18
- from torch_geometric.utils import get_embeddings
22
+ from torch_geometric.typing import EdgeType, NodeType
23
+ from torch_geometric.utils import get_embeddings, get_embeddings_hetero
19
24
 
20
25
 
21
26
  class PGExplainer(ExplainerAlgorithm):
@@ -62,6 +67,13 @@ class PGExplainer(ExplainerAlgorithm):
62
67
  'bias': 0.01,
63
68
  }
64
69
 
70
+ # NOTE: Add more in the future as needed.
71
+ SUPPORTED_HETERO_MODELS = [
72
+ HGTConv,
73
+ HANConv,
74
+ HeteroConv,
75
+ ]
76
+
65
77
  def __init__(self, epochs: int, lr: float = 0.003, **kwargs):
66
78
  super().__init__()
67
79
  self.epochs = epochs
@@ -75,11 +87,13 @@ class PGExplainer(ExplainerAlgorithm):
75
87
  )
76
88
  self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=lr)
77
89
  self._curr_epoch = -1
90
+ self.is_hetero = False
78
91
 
79
92
  def reset_parameters(self):
80
93
  r"""Resets all learnable parameters of the module."""
81
94
  reset(self.mlp)
82
95
 
96
+ @overload
83
97
  def train(
84
98
  self,
85
99
  epoch: int,
@@ -90,17 +104,44 @@ class PGExplainer(ExplainerAlgorithm):
90
104
  target: Tensor,
91
105
  index: Optional[Union[int, Tensor]] = None,
92
106
  **kwargs,
93
- ):
107
+ ) -> float:
108
+ ...
109
+
110
+ @overload
111
+ def train(
112
+ self,
113
+ epoch: int,
114
+ model: torch.nn.Module,
115
+ x: Dict[NodeType, Tensor],
116
+ edge_index: Dict[EdgeType, Tensor],
117
+ *,
118
+ target: Tensor,
119
+ index: Optional[Union[int, Tensor]] = None,
120
+ **kwargs,
121
+ ) -> float:
122
+ ...
123
+
124
+ def train(
125
+ self,
126
+ epoch: int,
127
+ model: torch.nn.Module,
128
+ x: Union[Tensor, Dict[NodeType, Tensor]],
129
+ edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
130
+ *,
131
+ target: Tensor,
132
+ index: Optional[Union[int, Tensor]] = None,
133
+ **kwargs,
134
+ ) -> float:
94
135
  r"""Trains the underlying explainer model.
95
136
  Needs to be called before being able to make predictions.
96
137
 
97
138
  Args:
98
139
  epoch (int): The current epoch of the training phase.
99
140
  model (torch.nn.Module): The model to explain.
100
- x (torch.Tensor): The input node features of a
101
- homogeneous graph.
102
- edge_index (torch.Tensor): The input edge indices of a homogeneous
103
- graph.
141
+ x (torch.Tensor or Dict[str, torch.Tensor]): The input node
142
+ features. Can be either homogeneous or heterogeneous.
143
+ edge_index (torch.Tensor or Dict[Tuple[str, str, str]): The input
144
+ edge indices. Can be either homogeneous or heterogeneous.
104
145
  target (torch.Tensor): The target of the model.
105
146
  index (int or torch.Tensor, optional): The index of the model
106
147
  output to explain. Needs to be a single index.
@@ -108,9 +149,9 @@ class PGExplainer(ExplainerAlgorithm):
108
149
  **kwargs (optional): Additional keyword arguments passed to
109
150
  :obj:`model`.
110
151
  """
111
- if isinstance(x, dict) or isinstance(edge_index, dict):
112
- raise ValueError(f"Heterogeneous graphs not yet supported in "
113
- f"'{self.__class__.__name__}'")
152
+ self.is_hetero = isinstance(x, dict)
153
+ if self.is_hetero:
154
+ assert isinstance(edge_index, dict)
114
155
 
115
156
  if self.model_config.task_level == ModelTaskLevel.node:
116
157
  if index is None:
@@ -121,35 +162,68 @@ class PGExplainer(ExplainerAlgorithm):
121
162
  raise ValueError(f"Only scalars are supported for the 'index' "
122
163
  f"argument in '{self.__class__.__name__}'")
123
164
 
124
- z = get_embeddings(model, x, edge_index, **kwargs)[-1]
165
+ # Get embeddings based on whether the graph is homogeneous or
166
+ # heterogeneous
167
+ node_embeddings = self._get_embeddings(model, x, edge_index, **kwargs)
125
168
 
169
+ # Train the model
126
170
  self.optimizer.zero_grad()
127
171
  temperature = self._get_temperature(epoch)
128
172
 
129
- inputs = self._get_inputs(z, edge_index, index)
130
- logits = self.mlp(inputs).view(-1)
131
- edge_mask = self._concrete_sample(logits, temperature)
132
- set_masks(model, edge_mask, edge_index, apply_sigmoid=True)
133
-
134
- if self.model_config.task_level == ModelTaskLevel.node:
135
- _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,
136
- num_nodes=x.size(0))
137
- edge_mask = edge_mask[hard_edge_mask]
138
-
173
+ # Process embeddings and generate edge masks
174
+ edge_mask = self._generate_edge_masks(node_embeddings, edge_index,
175
+ index, temperature)
176
+
177
+ # Apply masks to the model
178
+ if self.is_hetero:
179
+ set_hetero_masks(model, edge_mask, edge_index, apply_sigmoid=True)
180
+
181
+ # For node-level tasks, we can compute hard masks
182
+ if self.model_config.task_level == ModelTaskLevel.node:
183
+ # Process each edge type separately
184
+ for edge_type, mask in edge_mask.items():
185
+ # Get the edge indices for this edge type
186
+ edges = edge_index[edge_type]
187
+ src_type, _, dst_type = edge_type
188
+
189
+ # Get hard masks for this specific edge type
190
+ _, hard_mask = self._get_hard_masks(
191
+ model, index, edges,
192
+ num_nodes=max(x[src_type].size(0),
193
+ x[dst_type].size(0)))
194
+
195
+ edge_mask[edge_type] = mask[hard_mask]
196
+ else:
197
+ # Apply masks for homogeneous graphs
198
+ set_masks(model, edge_mask, edge_index, apply_sigmoid=True)
199
+
200
+ # For node-level tasks, we may need to apply hard masks
201
+ hard_edge_mask = None
202
+ if self.model_config.task_level == ModelTaskLevel.node:
203
+ _, hard_edge_mask = self._get_hard_masks(
204
+ model, index, edge_index, num_nodes=x.size(0))
205
+ edge_mask = edge_mask[hard_edge_mask]
206
+
207
+ # Forward pass with masks applied
139
208
  y_hat, y = model(x, edge_index, **kwargs), target
140
209
 
141
210
  if index is not None:
142
211
  y_hat, y = y_hat[index], y[index]
143
212
 
213
+ # Calculate loss
144
214
  loss = self._loss(y_hat, y, edge_mask)
215
+
216
+ # Backward pass and optimization
145
217
  loss.backward()
146
218
  self.optimizer.step()
147
219
 
220
+ # Clean up
148
221
  clear_masks(model)
149
222
  self._curr_epoch = epoch
150
223
 
151
224
  return float(loss)
152
225
 
226
+ @overload
153
227
  def forward(
154
228
  self,
155
229
  model: torch.nn.Module,
@@ -160,9 +234,32 @@ class PGExplainer(ExplainerAlgorithm):
160
234
  index: Optional[Union[int, Tensor]] = None,
161
235
  **kwargs,
162
236
  ) -> Explanation:
163
- if isinstance(x, dict) or isinstance(edge_index, dict):
164
- raise ValueError(f"Heterogeneous graphs not yet supported in "
165
- f"'{self.__class__.__name__}'")
237
+ ...
238
+
239
+ @overload
240
+ def forward(
241
+ self,
242
+ model: torch.nn.Module,
243
+ x: Dict[NodeType, Tensor],
244
+ edge_index: Dict[EdgeType, Tensor],
245
+ *,
246
+ target: Tensor,
247
+ index: Optional[Union[int, Tensor]] = None,
248
+ **kwargs,
249
+ ) -> HeteroExplanation:
250
+ ...
251
+
252
+ def forward(
253
+ self,
254
+ model: torch.nn.Module,
255
+ x: Union[Tensor, Dict[NodeType, Tensor]],
256
+ edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
257
+ *,
258
+ target: Tensor,
259
+ index: Optional[Union[int, Tensor]] = None,
260
+ **kwargs,
261
+ ) -> Union[Explanation, HeteroExplanation]:
262
+ self.is_hetero = isinstance(x, dict)
166
263
 
167
264
  if self._curr_epoch < self.epochs - 1: # Safety check:
168
265
  raise ValueError(f"'{self.__class__.__name__}' is not yet fully "
@@ -171,7 +268,6 @@ class PGExplainer(ExplainerAlgorithm):
171
268
  f"the underlying explainer model by running "
172
269
  f"`explainer.algorithm.train(...)`.")
173
270
 
174
- hard_edge_mask = None
175
271
  if self.model_config.task_level == ModelTaskLevel.node:
176
272
  if index is None:
177
273
  raise ValueError(f"The 'index' argument needs to be provided "
@@ -181,20 +277,55 @@ class PGExplainer(ExplainerAlgorithm):
181
277
  raise ValueError(f"Only scalars are supported for the 'index' "
182
278
  f"argument in '{self.__class__.__name__}'")
183
279
 
184
- # We need to compute hard masks to properly clean up edges and
185
- # nodes attributions not involved during message passing:
186
- _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,
187
- num_nodes=x.size(0))
280
+ # Get embeddings
281
+ node_embeddings = self._get_embeddings(model, x, edge_index, **kwargs)
282
+
283
+ # Generate explanations
284
+ if self.is_hetero:
285
+ # Generate edge masks for each edge type
286
+ edge_masks = {}
287
+
288
+ # Generate masks for each edge type
289
+ for edge_type, edge_idx in edge_index.items():
290
+ src_node_type, _, dst_node_type = edge_type
291
+
292
+ assert src_node_type in node_embeddings
293
+ assert dst_node_type in node_embeddings
294
+
295
+ inputs = self._get_inputs_hetero(node_embeddings, edge_type,
296
+ edge_idx, index)
297
+ logits = self.mlp(inputs).view(-1)
298
+
299
+ # For node-level explanations, get hard masks for this
300
+ # specific edge type
301
+ hard_edge_mask = None
302
+ if self.model_config.task_level == ModelTaskLevel.node:
303
+ _, hard_edge_mask = self._get_hard_masks(
304
+ model, index, edge_idx,
305
+ num_nodes=max(x[src_node_type].size(0),
306
+ x[dst_node_type].size(0)))
188
307
 
189
- z = get_embeddings(model, x, edge_index, **kwargs)[-1]
308
+ # Apply hard mask if available and it has any True values
309
+ edge_masks[edge_type] = self._post_process_mask(
310
+ logits, hard_edge_mask, apply_sigmoid=True)
190
311
 
191
- inputs = self._get_inputs(z, edge_index, index)
192
- logits = self.mlp(inputs).view(-1)
312
+ explanation = HeteroExplanation()
313
+ explanation.set_value_dict('edge_mask', edge_masks)
314
+ return explanation
315
+ else:
316
+ hard_edge_mask = None
317
+ if self.model_config.task_level == ModelTaskLevel.node:
318
+ # We need to compute hard masks to properly clean up edges
319
+ _, hard_edge_mask = self._get_hard_masks(
320
+ model, index, edge_index, num_nodes=x.size(0))
193
321
 
194
- edge_mask = self._post_process_mask(logits, hard_edge_mask,
195
- apply_sigmoid=True)
322
+ inputs = self._get_inputs(node_embeddings, edge_index, index)
323
+ logits = self.mlp(inputs).view(-1)
196
324
 
197
- return Explanation(edge_mask=edge_mask)
325
+ edge_mask = self._post_process_mask(logits, hard_edge_mask,
326
+ apply_sigmoid=True)
327
+
328
+ return Explanation(edge_mask=edge_mask)
198
329
 
199
330
  def supports(self) -> bool:
200
331
  explanation_type = self.explainer_config.explanation_type
@@ -222,6 +353,76 @@ class PGExplainer(ExplainerAlgorithm):
222
353
 
223
354
  ###########################################################################
224
355
 
356
+ def _get_embeddings(self, model: torch.nn.Module, x: Union[Tensor,
357
+ Dict[NodeType,
358
+ Tensor]],
359
+ edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
360
+ **kwargs) -> Union[Tensor, Dict[NodeType, Tensor]]:
361
+ """Get embeddings from the model based on input type."""
362
+ if self.is_hetero:
363
+ # For heterogeneous graphs, get embeddings for each node type
364
+ embeddings_dict = get_embeddings_hetero(
365
+ model,
366
+ self.SUPPORTED_HETERO_MODELS,
367
+ x,
368
+ edge_index,
369
+ **kwargs,
370
+ )
371
+
372
+ # Use the last layer's embeddings for each node type
373
+ last_embedding_dict = {
374
+ node_type: embs[-1] if embs and len(embs) > 0 else None
375
+ for node_type, embs in embeddings_dict.items()
376
+ }
377
+
378
+ # Skip if no embeddings were captured
379
+ if not any(emb is not None
380
+ for emb in last_embedding_dict.values()):
381
+ raise ValueError(
382
+ "No embeddings were captured from the model. "
383
+ "Please check if the model architecture is supported.")
384
+
385
+ return last_embedding_dict
386
+ else:
387
+ # For homogeneous graphs, get embeddings directly
388
+ return get_embeddings(model, x, edge_index, **kwargs)[-1]
389
+
390
+ def _generate_edge_masks(
391
+ self, emb: Union[Tensor, Dict[NodeType, Tensor]],
392
+ edge_index: Union[Tensor,
393
+ Dict[EdgeType,
394
+ Tensor]], index: Optional[Union[int,
395
+ Tensor]],
396
+ temperature: float) -> Union[Tensor, Dict[EdgeType, Tensor]]:
397
+ """Generate edge masks based on embeddings."""
398
+ if self.is_hetero:
399
+ # For heterogeneous graphs, generate masks for each edge type
400
+ edge_masks = {}
401
+
402
+ for edge_type, edge_idx in edge_index.items():
403
+ src, _, dst = edge_type
404
+
405
+ assert src in emb and dst in emb
406
+ # Generate inputs for this edge type
407
+ inputs = self._get_inputs_hetero(emb, edge_type, edge_idx,
408
+ index)
409
+ logits = self.mlp(inputs).view(-1)
410
+ edge_masks[edge_type] = self._concrete_sample(
411
+ logits, temperature)
412
+
413
+ # Ensure we have at least one valid edge mask
414
+ if not edge_masks:
415
+ raise ValueError(
416
+ "Could not generate edge masks for any edge type. "
417
+ "Please ensure the model architecture is supported.")
418
+
419
+ return edge_masks
420
+ else:
421
+ # For homogeneous graphs, generate a single mask
422
+ inputs = self._get_inputs(emb, edge_index, index)
423
+ logits = self.mlp(inputs).view(-1)
424
+ return self._concrete_sample(logits, temperature)
425
+
225
426
  def _get_inputs(self, embedding: Tensor, edge_index: Tensor,
226
427
  index: Optional[int] = None) -> Tensor:
227
428
  zs = [embedding[edge_index[0]], embedding[edge_index[1]]]
@@ -230,6 +431,27 @@ class PGExplainer(ExplainerAlgorithm):
230
431
  zs.append(embedding[index].view(1, -1).repeat(zs[0].size(0), 1))
231
432
  return torch.cat(zs, dim=-1)
232
433
 
434
+ def _get_inputs_hetero(self, embedding_dict: Dict[NodeType, Tensor],
435
+ edge_type: Tuple[str, str, str], edge_index: Tensor,
436
+ index: Optional[int] = None) -> Tensor:
437
+ src, _, dst = edge_type
438
+
439
+ # Get embeddings for source and destination nodes
440
+ src_emb = embedding_dict[src]
441
+ dst_emb = embedding_dict[dst]
442
+
443
+ # Source and destination node embeddings
444
+ zs = [src_emb[edge_index[0]], dst_emb[edge_index[1]]]
445
+
446
+ # For node-level explanations, add the target node embedding
447
+ if self.model_config.task_level == ModelTaskLevel.node:
448
+ assert index is not None
449
+ # Assuming index refers to a node of type 'src'
450
+ target_emb = src_emb[index].view(1, -1).repeat(zs[0].size(0), 1)
451
+ zs.append(target_emb)
452
+
453
+ return torch.cat(zs, dim=-1)
454
+
233
455
  def _get_temperature(self, epoch: int) -> float:
234
456
  temp = self.coeffs['temp']
235
457
  return temp[0] * pow(temp[1] / temp[0], epoch / self.epochs)
@@ -240,19 +462,55 @@ class PGExplainer(ExplainerAlgorithm):
240
462
  eps = (1 - 2 * bias) * torch.rand_like(logits) + bias
241
463
  return (eps.log() - (1 - eps).log() + logits) / temperature
242
464
 
243
- def _loss(self, y_hat: Tensor, y: Tensor, edge_mask: Tensor) -> Tensor:
465
+ def _loss(self, y_hat: Tensor, y: Tensor,
466
+ edge_mask: Union[Tensor, Dict[EdgeType, Tensor]]) -> Tensor:
467
+ # Calculate base loss based on model configuration
468
+ loss = self._calculate_base_loss(y_hat, y)
469
+
470
+ # Apply regularization based on graph type
471
+ if self.is_hetero:
472
+ loss = self._apply_hetero_regularization(loss, edge_mask)
473
+ else:
474
+ loss = self._apply_homo_regularization(loss, edge_mask)
475
+
476
+ return loss
477
+
478
+ def _calculate_base_loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
479
+ """Calculate base loss based on model configuration."""
244
480
  if self.model_config.mode == ModelMode.binary_classification:
245
- loss = self._loss_binary_classification(y_hat, y)
481
+ return self._loss_binary_classification(y_hat, y)
246
482
  elif self.model_config.mode == ModelMode.multiclass_classification:
247
- loss = self._loss_multiclass_classification(y_hat, y)
483
+ return self._loss_multiclass_classification(y_hat, y)
248
484
  elif self.model_config.mode == ModelMode.regression:
249
- loss = self._loss_regression(y_hat, y)
250
-
251
- # Regularization loss:
252
- mask = edge_mask.sigmoid()
485
+ return self._loss_regression(y_hat, y)
486
+ else:
487
+ raise ValueError(
488
+ f"Unsupported model mode: {self.model_config.mode}")
489
+
490
+ def _apply_hetero_regularization(
491
+ self, loss: Tensor, edge_mask: Dict[EdgeType, Tensor]) -> Tensor:
492
+ """Apply regularization for heterogeneous graph."""
493
+ for _, mask in edge_mask.items():
494
+ loss = self._add_mask_regularization(loss, mask)
495
+
496
+ return loss
497
+
498
+ def _apply_homo_regularization(self, loss: Tensor,
499
+ edge_mask: Tensor) -> Tensor:
500
+ """Apply regularization for homogeneous graph."""
501
+ return self._add_mask_regularization(loss, edge_mask)
502
+
503
+ def _add_mask_regularization(self, loss: Tensor, mask: Tensor) -> Tensor:
504
+ """Add size and entropy regularization for a mask."""
505
+ # Apply sigmoid for mask values
506
+ mask = mask.sigmoid()
507
+
508
+ # Size regularization
253
509
  size_loss = mask.sum() * self.coeffs['edge_size']
254
- mask = 0.99 * mask + 0.005
255
- mask_ent = -mask * mask.log() - (1 - mask) * (1 - mask).log()
510
+
511
+ # Entropy regularization
512
+ masked = 0.99 * mask + 0.005
513
+ mask_ent = -masked * masked.log() - (1 - masked) * (1 - masked).log()
256
514
  mask_ent_loss = mask_ent.mean() * self.coeffs['edge_ent']
257
515
 
258
516
  return loss + size_loss + mask_ent_loss
@@ -53,7 +53,7 @@ from ._negative_sampling import (negative_sampling, batched_negative_sampling,
53
53
  structured_negative_sampling_feasible)
54
54
  from .augmentation import shuffle_node, mask_feature, add_random_edge
55
55
  from ._tree_decomposition import tree_decomposition
56
- from .embedding import get_embeddings
56
+ from .embedding import get_embeddings, get_embeddings_hetero
57
57
  from ._trim_to_layer import trim_to_layer
58
58
  from .ppr import get_ppr
59
59
  from ._train_test_split_edges import train_test_split_edges
@@ -145,6 +145,7 @@ __all__ = [
145
145
  'add_random_edge',
146
146
  'tree_decomposition',
147
147
  'get_embeddings',
148
+ 'get_embeddings_hetero',
148
149
  'trim_to_layer',
149
150
  'get_ppr',
150
151
  'train_test_split_edges',
@@ -1,9 +1,11 @@
1
1
  import warnings
2
- from typing import Any, List
2
+ from typing import Any, Dict, List, Optional, Type
3
3
 
4
4
  import torch
5
5
  from torch import Tensor
6
6
 
7
+ from torch_geometric.typing import NodeType
8
+
7
9
 
8
10
  def get_embeddings(
9
11
  model: torch.nn.Module,
@@ -52,3 +54,88 @@ def get_embeddings(
52
54
  handle.remove()
53
55
 
54
56
  return embeddings
57
+
58
+
59
+ def get_embeddings_hetero(
60
+ model: torch.nn.Module,
61
+ supported_models: Optional[List[Type[torch.nn.Module]]] = None,
62
+ *args: Any,
63
+ **kwargs: Any,
64
+ ) -> Dict[NodeType, List[Tensor]]:
65
+ """Returns the output embeddings of all
66
+ :class:`~torch_geometric.nn.conv.MessagePassing` layers in a heterogeneous
67
+ :obj:`model`, organized by edge type.
68
+
69
+ Internally, this method registers forward hooks on all modules that process
70
+ heterogeneous graphs in the model and runs the forward pass of the model.
71
+ For heterogeneous models, the output is a dictionary where each key is a
72
+ node type and each value is a list of embeddings from different layers.
73
+
74
+ Args:
75
+ model (torch.nn.Module): The heterogeneous GNN model.
76
+ supported_models (List[Type[torch.nn.Module]], optional): A list of
77
+ supported model classes. If not provided, defaults to
78
+ [HGTConv, HANConv, HeteroConv].
79
+ *args: Arguments passed to the model.
80
+ **kwargs (optional): Additional keyword arguments passed to the model.
81
+
82
+ Returns:
83
+ Dict[NodeType, List[Tensor]]: A dictionary mapping each node type to
84
+ a list of embeddings from different layers.
85
+ """
86
+ from torch_geometric.nn import HANConv, HeteroConv, HGTConv
87
+ if not supported_models:
88
+ supported_models = [HGTConv, HANConv, HeteroConv]
89
+
90
+ # Dictionary to store node embeddings by type
91
+ node_embeddings_dict: Dict[NodeType, List[Tensor]] = {}
92
+
93
+ # Hook function to capture node embeddings
94
+ def hook(model: torch.nn.Module, inputs: Any, outputs: Any) -> None:
95
+ # Check if the outputs is a dictionary mapping node types to embeddings
96
+ if isinstance(outputs, dict) and outputs:
97
+ # Store embeddings for each node type
98
+ for node_type, embedding in outputs.items():
99
+ # Made sure that the outputs are a dictionary mapping node
100
+ # types to embeddings and remove the false positives.
101
+ if node_type not in node_embeddings_dict:
102
+ node_embeddings_dict[node_type] = []
103
+ node_embeddings_dict[node_type].append(embedding.clone())
104
+
105
+ # List to store hook handles
106
+ hook_handles = []
107
+
108
+ # Find ModuleDict objects in the model
109
+ for _, module in model.named_modules():
110
+ # Handle the native heterogenous models, e.g. HGTConv, HANConv
111
+ # and HeteroConv, etc.
112
+ if isinstance(module, tuple(supported_models)):
113
+ hook_handles.append(module.register_forward_hook(hook))
114
+ else:
115
+ # Handle the heterogenous models that are generated by calling
116
+ # to_hetero() on the homogeneous models.
117
+ submodules = list(module.children())
118
+ submodules_contains_module_dict = any([
119
+ isinstance(submodule, torch.nn.ModuleDict)
120
+ for submodule in submodules
121
+ ])
122
+ if submodules_contains_module_dict:
123
+ hook_handles.append(module.register_forward_hook(hook))
124
+
125
+ if len(hook_handles) == 0:
126
+ warnings.warn("The 'model' does not have any heterogenous "
127
+ "'MessagePassing' layers")
128
+
129
+ # Run the model forward pass
130
+ training = model.training
131
+ model.eval()
132
+
133
+ with torch.no_grad():
134
+ model(*args, **kwargs)
135
+ model.train(training)
136
+
137
+ # Clean up hooks
138
+ for handle in hook_handles:
139
+ handle.remove()
140
+
141
+ return node_embeddings_dict