pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl

This diff shows the content of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (229)
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_to_dense_batch.py +2 -2
  215. torch_geometric/utils/_trim_to_layer.py +2 -2
  216. torch_geometric/utils/convert.py +17 -10
  217. torch_geometric/utils/cross_entropy.py +34 -13
  218. torch_geometric/utils/embedding.py +91 -2
  219. torch_geometric/utils/geodesic.py +4 -3
  220. torch_geometric/utils/influence.py +279 -0
  221. torch_geometric/utils/map.py +13 -9
  222. torch_geometric/utils/nested.py +1 -1
  223. torch_geometric/utils/smiles.py +3 -3
  224. torch_geometric/utils/sparse.py +7 -14
  225. torch_geometric/visualization/__init__.py +2 -1
  226. torch_geometric/visualization/graph.py +250 -5
  227. torch_geometric/warnings.py +11 -2
  228. torch_geometric/nn/nlp/__init__.py +0 -7
  229. torch_geometric/nn/nlp/sentence_transformer.py +0 -101

torch_geometric/explain/algorithm/pg_explainer.py

@@ -1,21 +1,26 @@
 import logging
-from typing import Optional, Union
+from typing import Dict, Optional, Tuple, Union, overload
 
 import torch
 from torch import Tensor
 from torch.nn import ReLU, Sequential
 
-from torch_geometric.explain import Explanation
+from torch_geometric.explain import Explanation, HeteroExplanation
 from torch_geometric.explain.algorithm import ExplainerAlgorithm
-from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
+from torch_geometric.explain.algorithm.utils import (
+    clear_masks,
+    set_hetero_masks,
+    set_masks,
+)
 from torch_geometric.explain.config import (
     ExplanationType,
     ModelMode,
     ModelTaskLevel,
 )
-from torch_geometric.nn import Linear
+from torch_geometric.nn import HANConv, HeteroConv, HGTConv, Linear
 from torch_geometric.nn.inits import reset
-from torch_geometric.utils import get_embeddings
+from torch_geometric.typing import EdgeType, NodeType
+from torch_geometric.utils import get_embeddings, get_embeddings_hetero
 
 
 class PGExplainer(ExplainerAlgorithm):
@@ -62,6 +67,13 @@ class PGExplainer(ExplainerAlgorithm):
         'bias': 0.01,
     }
 
+    # NOTE: Add more in the future as needed.
+    SUPPORTED_HETERO_MODELS = [
+        HGTConv,
+        HANConv,
+        HeteroConv,
+    ]
+
     def __init__(self, epochs: int, lr: float = 0.003, **kwargs):
         super().__init__()
         self.epochs = epochs
@@ -75,11 +87,13 @@ class PGExplainer(ExplainerAlgorithm):
         )
         self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=lr)
         self._curr_epoch = -1
+        self.is_hetero = False
 
     def reset_parameters(self):
         r"""Resets all learnable parameters of the module."""
         reset(self.mlp)
 
+    @overload
     def train(
         self,
         epoch: int,
@@ -90,17 +104,44 @@ class PGExplainer(ExplainerAlgorithm):
         target: Tensor,
         index: Optional[Union[int, Tensor]] = None,
         **kwargs,
-    ):
+    ) -> float:
+        ...
+
+    @overload
+    def train(
+        self,
+        epoch: int,
+        model: torch.nn.Module,
+        x: Dict[NodeType, Tensor],
+        edge_index: Dict[EdgeType, Tensor],
+        *,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ) -> float:
+        ...
+
+    def train(
+        self,
+        epoch: int,
+        model: torch.nn.Module,
+        x: Union[Tensor, Dict[NodeType, Tensor]],
+        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+        *,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ) -> float:
         r"""Trains the underlying explainer model.
         Needs to be called before being able to make predictions.
 
         Args:
             epoch (int): The current epoch of the training phase.
             model (torch.nn.Module): The model to explain.
-            x (torch.Tensor): The input node features of a
-                homogeneous graph.
-            edge_index (torch.Tensor): The input edge indices of a homogeneous
-                graph.
+            x (torch.Tensor or Dict[str, torch.Tensor]): The input node
+                features. Can be either homogeneous or heterogeneous.
+            edge_index (torch.Tensor or Dict[Tuple[str, str, str]): The input
+                edge indices. Can be either homogeneous or heterogeneous.
             target (torch.Tensor): The target of the model.
            index (int or torch.Tensor, optional): The index of the model
                output to explain. Needs to be a single index.
@@ -108,9 +149,9 @@ class PGExplainer(ExplainerAlgorithm):
             **kwargs (optional): Additional keyword arguments passed to
                 :obj:`model`.
         """
-        if isinstance(x, dict) or isinstance(edge_index, dict):
-            raise ValueError(f"Heterogeneous graphs not yet supported in "
-                             f"'{self.__class__.__name__}'")
+        self.is_hetero = isinstance(x, dict)
+        if self.is_hetero:
+            assert isinstance(edge_index, dict)
 
         if self.model_config.task_level == ModelTaskLevel.node:
             if index is None:
@@ -121,35 +162,68 @@ class PGExplainer(ExplainerAlgorithm):
                 raise ValueError(f"Only scalars are supported for the 'index' "
                                  f"argument in '{self.__class__.__name__}'")
 
-        z = get_embeddings(model, x, edge_index, **kwargs)[-1]
+        # Get embeddings based on whether the graph is homogeneous or
+        # heterogeneous
+        node_embeddings = self._get_embeddings(model, x, edge_index, **kwargs)
 
+        # Train the model
         self.optimizer.zero_grad()
         temperature = self._get_temperature(epoch)
 
-        inputs = self._get_inputs(z, edge_index, index)
-        logits = self.mlp(inputs).view(-1)
-        edge_mask = self._concrete_sample(logits, temperature)
-        set_masks(model, edge_mask, edge_index, apply_sigmoid=True)
-
-        if self.model_config.task_level == ModelTaskLevel.node:
-            _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,
-                                                     num_nodes=x.size(0))
-            edge_mask = edge_mask[hard_edge_mask]
-
+        # Process embeddings and generate edge masks
+        edge_mask = self._generate_edge_masks(node_embeddings, edge_index,
+                                              index, temperature)
+
+        # Apply masks to the model
+        if self.is_hetero:
+            set_hetero_masks(model, edge_mask, edge_index, apply_sigmoid=True)
+
+            # For node-level tasks, we can compute hard masks
+            if self.model_config.task_level == ModelTaskLevel.node:
+                # Process each edge type separately
+                for edge_type, mask in edge_mask.items():
+                    # Get the edge indices for this edge type
+                    edges = edge_index[edge_type]
+                    src_type, _, dst_type = edge_type
+
+                    # Get hard masks for this specific edge type
+                    _, hard_mask = self._get_hard_masks(
+                        model, index, edges,
+                        num_nodes=max(x[src_type].size(0),
+                                      x[dst_type].size(0)))
+
+                    edge_mask[edge_type] = mask[hard_mask]
+        else:
+            # Apply masks for homogeneous graphs
+            set_masks(model, edge_mask, edge_index, apply_sigmoid=True)
+
+            # For node-level tasks, we may need to apply hard masks
+            hard_edge_mask = None
+            if self.model_config.task_level == ModelTaskLevel.node:
+                _, hard_edge_mask = self._get_hard_masks(
+                    model, index, edge_index, num_nodes=x.size(0))
+                edge_mask = edge_mask[hard_edge_mask]
+
+        # Forward pass with masks applied
         y_hat, y = model(x, edge_index, **kwargs), target
 
         if index is not None:
             y_hat, y = y_hat[index], y[index]
 
+        # Calculate loss
         loss = self._loss(y_hat, y, edge_mask)
+
+        # Backward pass and optimization
         loss.backward()
         self.optimizer.step()
 
+        # Clean up
         clear_masks(model)
         self._curr_epoch = epoch
 
         return float(loss)
 
+    @overload
     def forward(
         self,
         model: torch.nn.Module,
@@ -160,9 +234,32 @@ class PGExplainer(ExplainerAlgorithm):
         index: Optional[Union[int, Tensor]] = None,
         **kwargs,
     ) -> Explanation:
-        if isinstance(x, dict) or isinstance(edge_index, dict):
-            raise ValueError(f"Heterogeneous graphs not yet supported in "
-                             f"'{self.__class__.__name__}'")
+        ...
+
+    @overload
+    def forward(
+        self,
+        model: torch.nn.Module,
+        x: Dict[NodeType, Tensor],
+        edge_index: Dict[EdgeType, Tensor],
+        *,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ) -> HeteroExplanation:
+        ...
+
+    def forward(
+        self,
+        model: torch.nn.Module,
+        x: Union[Tensor, Dict[NodeType, Tensor]],
+        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+        *,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ) -> Union[Explanation, HeteroExplanation]:
+        self.is_hetero = isinstance(x, dict)
 
         if self._curr_epoch < self.epochs - 1:  # Safety check:
             raise ValueError(f"'{self.__class__.__name__}' is not yet fully "
@@ -171,7 +268,6 @@ class PGExplainer(ExplainerAlgorithm):
                              f"the underlying explainer model by running "
                              f"`explainer.algorithm.train(...)`.")
 
-        hard_edge_mask = None
         if self.model_config.task_level == ModelTaskLevel.node:
             if index is None:
                 raise ValueError(f"The 'index' argument needs to be provided "
@@ -181,20 +277,55 @@ class PGExplainer(ExplainerAlgorithm):
                 raise ValueError(f"Only scalars are supported for the 'index' "
                                  f"argument in '{self.__class__.__name__}'")
 
-        # We need to compute hard masks to properly clean up edges and
-        # nodes attributions not involved during message passing:
-        _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,
-                                                 num_nodes=x.size(0))
+        # Get embeddings
+        node_embeddings = self._get_embeddings(model, x, edge_index, **kwargs)
+
+        # Generate explanations
+        if self.is_hetero:
+            # Generate edge masks for each edge type
+            edge_masks = {}
+
+            # Generate masks for each edge type
+            for edge_type, edge_idx in edge_index.items():
+                src_node_type, _, dst_node_type = edge_type
+
+                assert src_node_type in node_embeddings
+                assert dst_node_type in node_embeddings
+
+                inputs = self._get_inputs_hetero(node_embeddings, edge_type,
+                                                 edge_idx, index)
+                logits = self.mlp(inputs).view(-1)
+
+                # For node-level explanations, get hard masks for this
+                # specific edge type
+                hard_edge_mask = None
+                if self.model_config.task_level == ModelTaskLevel.node:
+                    _, hard_edge_mask = self._get_hard_masks(
+                        model, index, edge_idx,
+                        num_nodes=max(x[src_node_type].size(0),
+                                      x[dst_node_type].size(0)))
 
-        z = get_embeddings(model, x, edge_index, **kwargs)[-1]
+                # Apply hard mask if available and it has any True values
+                edge_masks[edge_type] = self._post_process_mask(
+                    logits, hard_edge_mask, apply_sigmoid=True)
 
-        inputs = self._get_inputs(z, edge_index, index)
-        logits = self.mlp(inputs).view(-1)
+            explanation = HeteroExplanation()
+            explanation.set_value_dict('edge_mask', edge_masks)
+            return explanation
+        else:
+            hard_edge_mask = None
+            if self.model_config.task_level == ModelTaskLevel.node:
+                # We need to compute hard masks to properly clean up edges
+                _, hard_edge_mask = self._get_hard_masks(
+                    model, index, edge_index, num_nodes=x.size(0))
 
-        edge_mask = self._post_process_mask(logits, hard_edge_mask,
-                                            apply_sigmoid=True)
+            inputs = self._get_inputs(node_embeddings, edge_index, index)
+            logits = self.mlp(inputs).view(-1)
 
-        return Explanation(edge_mask=edge_mask)
+            edge_mask = self._post_process_mask(logits, hard_edge_mask,
+                                                apply_sigmoid=True)
+
+            return Explanation(edge_mask=edge_mask)
 
     def supports(self) -> bool:
         explanation_type = self.explainer_config.explanation_type
@@ -222,6 +353,76 @@ class PGExplainer(ExplainerAlgorithm):
 
     ###########################################################################
 
+    def _get_embeddings(self, model: torch.nn.Module, x: Union[Tensor,
+                                                               Dict[NodeType,
+                                                                    Tensor]],
+                        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+                        **kwargs) -> Union[Tensor, Dict[NodeType, Tensor]]:
+        """Get embeddings from the model based on input type."""
+        if self.is_hetero:
+            # For heterogeneous graphs, get embeddings for each node type
+            embeddings_dict = get_embeddings_hetero(
+                model,
+                self.SUPPORTED_HETERO_MODELS,
+                x,
+                edge_index,
+                **kwargs,
+            )
+
+            # Use the last layer's embeddings for each node type
+            last_embedding_dict = {
+                node_type: embs[-1] if embs and len(embs) > 0 else None
+                for node_type, embs in embeddings_dict.items()
+            }
+
+            # Skip if no embeddings were captured
+            if not any(emb is not None
+                       for emb in last_embedding_dict.values()):
+                raise ValueError(
+                    "No embeddings were captured from the model. "
+                    "Please check if the model architecture is supported.")
+
+            return last_embedding_dict
+        else:
+            # For homogeneous graphs, get embeddings directly
+            return get_embeddings(model, x, edge_index, **kwargs)[-1]
+
+    def _generate_edge_masks(
+            self, emb: Union[Tensor, Dict[NodeType, Tensor]],
+            edge_index: Union[Tensor,
+                              Dict[EdgeType,
+                                   Tensor]], index: Optional[Union[int,
+                                                                   Tensor]],
+            temperature: float) -> Union[Tensor, Dict[EdgeType, Tensor]]:
+        """Generate edge masks based on embeddings."""
+        if self.is_hetero:
+            # For heterogeneous graphs, generate masks for each edge type
+            edge_masks = {}
+
+            for edge_type, edge_idx in edge_index.items():
+                src, _, dst = edge_type
+
+                assert src in emb and dst in emb
+                # Generate inputs for this edge type
+                inputs = self._get_inputs_hetero(emb, edge_type, edge_idx,
+                                                 index)
+                logits = self.mlp(inputs).view(-1)
+                edge_masks[edge_type] = self._concrete_sample(
+                    logits, temperature)
+
+            # Ensure we have at least one valid edge mask
+            if not edge_masks:
+                raise ValueError(
+                    "Could not generate edge masks for any edge type. "
+                    "Please ensure the model architecture is supported.")
+
+            return edge_masks
+        else:
+            # For homogeneous graphs, generate a single mask
+            inputs = self._get_inputs(emb, edge_index, index)
+            logits = self.mlp(inputs).view(-1)
+            return self._concrete_sample(logits, temperature)
+
     def _get_inputs(self, embedding: Tensor, edge_index: Tensor,
                     index: Optional[int] = None) -> Tensor:
         zs = [embedding[edge_index[0]], embedding[edge_index[1]]]
@@ -230,6 +431,27 @@ class PGExplainer(ExplainerAlgorithm):
             zs.append(embedding[index].view(1, -1).repeat(zs[0].size(0), 1))
         return torch.cat(zs, dim=-1)
 
+    def _get_inputs_hetero(self, embedding_dict: Dict[NodeType, Tensor],
+                           edge_type: Tuple[str, str, str], edge_index: Tensor,
+                           index: Optional[int] = None) -> Tensor:
+        src, _, dst = edge_type
+
+        # Get embeddings for source and destination nodes
+        src_emb = embedding_dict[src]
+        dst_emb = embedding_dict[dst]
+
+        # Source and destination node embeddings
+        zs = [src_emb[edge_index[0]], dst_emb[edge_index[1]]]
+
+        # For node-level explanations, add the target node embedding
+        if self.model_config.task_level == ModelTaskLevel.node:
+            assert index is not None
+            # Assuming index refers to a node of type 'src'
+            target_emb = src_emb[index].view(1, -1).repeat(zs[0].size(0), 1)
+            zs.append(target_emb)
+
+        return torch.cat(zs, dim=-1)
+
     def _get_temperature(self, epoch: int) -> float:
         temp = self.coeffs['temp']
         return temp[0] * pow(temp[1] / temp[0], epoch / self.epochs)
@@ -240,19 +462,55 @@ class PGExplainer(ExplainerAlgorithm):
         eps = (1 - 2 * bias) * torch.rand_like(logits) + bias
         return (eps.log() - (1 - eps).log() + logits) / temperature
 
-    def _loss(self, y_hat: Tensor, y: Tensor, edge_mask: Tensor) -> Tensor:
+    def _loss(self, y_hat: Tensor, y: Tensor,
+              edge_mask: Union[Tensor, Dict[EdgeType, Tensor]]) -> Tensor:
+        # Calculate base loss based on model configuration
+        loss = self._calculate_base_loss(y_hat, y)
+
+        # Apply regularization based on graph type
+        if self.is_hetero:
+            loss = self._apply_hetero_regularization(loss, edge_mask)
+        else:
+            loss = self._apply_homo_regularization(loss, edge_mask)
+
+        return loss
+
+    def _calculate_base_loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
+        """Calculate base loss based on model configuration."""
         if self.model_config.mode == ModelMode.binary_classification:
-            loss = self._loss_binary_classification(y_hat, y)
+            return self._loss_binary_classification(y_hat, y)
         elif self.model_config.mode == ModelMode.multiclass_classification:
-            loss = self._loss_multiclass_classification(y_hat, y)
+            return self._loss_multiclass_classification(y_hat, y)
         elif self.model_config.mode == ModelMode.regression:
-            loss = self._loss_regression(y_hat, y)
-
-        # Regularization loss:
-        mask = edge_mask.sigmoid()
+            return self._loss_regression(y_hat, y)
+        else:
+            raise ValueError(
+                f"Unsupported model mode: {self.model_config.mode}")
+
+    def _apply_hetero_regularization(
+            self, loss: Tensor, edge_mask: Dict[EdgeType, Tensor]) -> Tensor:
+        """Apply regularization for heterogeneous graph."""
+        for _, mask in edge_mask.items():
+            loss = self._add_mask_regularization(loss, mask)
+
+        return loss
+
+    def _apply_homo_regularization(self, loss: Tensor,
+                                   edge_mask: Tensor) -> Tensor:
+        """Apply regularization for homogeneous graph."""
+        return self._add_mask_regularization(loss, edge_mask)
+
+    def _add_mask_regularization(self, loss: Tensor, mask: Tensor) -> Tensor:
+        """Add size and entropy regularization for a mask."""
+        # Apply sigmoid for mask values
+        mask = mask.sigmoid()
+
+        # Size regularization
         size_loss = mask.sum() * self.coeffs['edge_size']
-        mask = 0.99 * mask + 0.005
-        mask_ent = -mask * mask.log() - (1 - mask) * (1 - mask).log()
+
+        # Entropy regularization
+        masked = 0.99 * mask + 0.005
+        mask_ent = -masked * masked.log() - (1 - masked) * (1 - masked).log()
         mask_ent_loss = mask_ent.mean() * self.coeffs['edge_ent']
 
         return loss + size_loss + mask_ent_loss
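
The hunks above add a heterogeneous code path to PGExplainer: train() and forward() now accept Dict[NodeType, Tensor] / Dict[EdgeType, Tensor] inputs and, for dictionary inputs, return a HeteroExplanation with one edge mask per edge type. Below is a minimal usage sketch, not part of the diff; the hetero GNN `model`, the HeteroData object `data`, and the 'paper' node type are placeholder assumptions.

from torch_geometric.explain import Explainer, PGExplainer

# Assumed setup: `model` is a heterogeneous GNN built from a supported layer
# (HGTConv, HANConv or HeteroConv) that returns node-level predictions for the
# 'paper' node type, and `data` is a HeteroData object with labels in
# data['paper'].y. Both are placeholders, not defined by this diff.
explainer = Explainer(
    model=model,
    algorithm=PGExplainer(epochs=30, lr=0.003),
    explanation_type='phenomenon',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='raw',
    ),
)

target = data['paper'].y

# Train the explainer MLP first; passing dictionaries selects the new
# heterogeneous overload of PGExplainer.train().
for epoch in range(explainer.algorithm.epochs):
    loss = explainer.algorithm.train(epoch, model, data.x_dict,
                                     data.edge_index_dict, target=target,
                                     index=0)

# Returns a HeteroExplanation carrying an 'edge_mask' per edge type.
explanation = explainer(data.x_dict, data.edge_index_dict, target=target,
                        index=0)
print(explanation.collect('edge_mask'))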

torch_geometric/explain/explainer.py

@@ -192,7 +192,7 @@ class Explainer:
             if target is not None:
                 warnings.warn(
                     f"The 'target' should not be provided for the explanation "
-                    f"type '{self.explanation_type.value}'")
+                    f"type '{self.explanation_type.value}'", stacklevel=2)
             prediction = self.get_prediction(x, edge_index, **kwargs)
             target = self.get_target(prediction)
 
@@ -265,7 +265,7 @@ class Explainer:
                 return (prediction > 0).long().view(-1)
             if self.model_config.return_type == ModelReturnType.probs:
                 return (prediction > 0.5).long().view(-1)
-            assert False
+            raise AssertionError()
 
         if self.model_config.mode == ModelMode.multiclass_classification:
             return prediction.argmax(dim=-1)

torch_geometric/explain/explanation.py

@@ -1,5 +1,5 @@
 import copy
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -8,7 +8,10 @@ from torch_geometric.data.data import Data, warn_or_raise
 from torch_geometric.data.hetero_data import HeteroData
 from torch_geometric.explain.config import ThresholdConfig, ThresholdType
 from torch_geometric.typing import EdgeType, NodeType
-from torch_geometric.visualization import visualize_graph
+from torch_geometric.visualization import (
+    visualize_graph,
+    visualize_hetero_graph,
+)
 
 
 class ExplanationMixin:
@@ -100,7 +103,7 @@ class ExplanationMixin:
             out[index] = 1.0
             return out.view(mask.size())
 
-        assert False
+        raise AssertionError()
 
     def threshold(
         self,
@@ -362,6 +365,87 @@ class HeteroExplanation(HeteroData, ExplanationMixin):
 
         return _visualize_score(score, all_feat_labels, path, top_k)
 
+    def visualize_graph(
+        self,
+        path: Optional[str] = None,
+        node_labels: Optional[Dict[NodeType, List[str]]] = None,
+        node_size_range: Tuple[float, float] = (50, 500),
+        node_opacity_range: Tuple[float, float] = (0.2, 1.0),
+        edge_width_range: Tuple[float, float] = (0.1, 2.0),
+        edge_opacity_range: Tuple[float, float] = (0.2, 1.0),
+    ) -> None:
+        r"""Visualizes the explanation subgraph using networkx, with edge
+        opacity corresponding to edge importance and node colors
+        corresponding to node types.
+
+        Args:
+            path (str, optional): The path to where the plot is saved.
+                If set to :obj:`None`, will visualize the plot on-the-fly.
+                (default: :obj:`None`)
+            node_labels (Dict[NodeType, List[str]], optional): The display
+                names of nodes for each node type that will be shown in the
+                visualization. (default: :obj:`None`)
+            node_size_range (Tuple[float, float], optional): The minimum and
+                maximum node size in the visualization.
+                (default: :obj:`(50, 500)`)
+            node_opacity_range (Tuple[float, float], optional): The minimum and
+                maximum node opacity in the visualization.
+                (default: :obj:`(0.2, 1.0)`)
+            edge_width_range (Tuple[float, float], optional): The minimum and
+                maximum edge width in the visualization.
+                (default: :obj:`(0.1, 2.0)`)
+            edge_opacity_range (Tuple[float, float], optional): The minimum and
+                maximum edge opacity in the visualization.
+                (default: :obj:`(0.2, 1.0)`)
+        """
+        # Validate node labels if provided
+        if node_labels is not None:
+            for node_type, labels in node_labels.items():
+                if node_type not in self.node_types:
+                    raise ValueError(
+                        f"Node type '{node_type}' in node_labels "
+                        f"does not exist in the explanation graph")
+                if len(labels) != self[node_type].num_nodes:
+                    raise ValueError(f"Number of labels for node type "
+                                     f"'{node_type}' (got {len(labels)}) does "
+                                     f"not match the number of nodes "
+                                     f"(got {self[node_type].num_nodes})")
+        # Get the explanation subgraph
+        subgraph = self.get_explanation_subgraph()
+
+        # Prepare edge indices and weights for each edge type
+        edge_index_dict = {}
+        edge_weight_dict = {}
+        for edge_type in subgraph.edge_types:
+            if edge_type[0] == 'x' or edge_type[-1] == 'x':  # Skip edges
+                continue
+            edge_index_dict[edge_type] = subgraph[edge_type].edge_index
+            edge_weight_dict[edge_type] = subgraph[edge_type].get(
+                'edge_mask',
+                torch.ones(subgraph[edge_type].edge_index.size(1)))
+
+        # Prepare node weights for each node type
+        node_weight_dict = {}
+        for node_type in subgraph.node_types:
+            if node_type == 'x':  # Skip the global store
+                continue
+            node_weight_dict[node_type] = subgraph[node_type] \
+                .get('node_mask',
+                     torch.ones(subgraph[node_type].num_nodes)).squeeze(-1)
+
+        # Call the visualization function
+        visualize_hetero_graph(
+            edge_index_dict=edge_index_dict,
+            edge_weight_dict=edge_weight_dict,
+            path=path,
+            node_labels_dict=node_labels,
+            node_weight_dict=node_weight_dict,
+            node_size_range=node_size_range,
+            node_opacity_range=node_opacity_range,
+            edge_width_range=edge_width_range,
+            edge_opacity_range=edge_opacity_range,
+        )
+
 
 def _visualize_score(
     score: torch.Tensor,
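
The added visualize_graph() method gives HeteroExplanation a direct plotting entry point that forwards per-type edge and node masks to visualize_hetero_graph. A short, hypothetical call continuing the heterogeneous PGExplainer sketch earlier in this diff (the label scheme is made up for illustration):

# `explanation` is assumed to be a HeteroExplanation, e.g. produced above.
node_labels = {
    node_type: [f'{node_type}_{i}'
                for i in range(explanation[node_type].num_nodes)]
    for node_type in explanation.node_types
}
explanation.visualize_graph(
    path='hetero_explanation.png',  # None would display the plot instead
    node_labels=node_labels,
    edge_width_range=(0.5, 3.0),
)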

torch_geometric/explain/metric/faithfulness.py

@@ -13,7 +13,7 @@ def unfaithfulness(
     top_k: Optional[int] = None,
 ) -> float:
     r"""Evaluates how faithful an :class:`~torch_geometric.explain.Explanation`
-    is to an underyling GNN predictor, as described in the
+    is to an underlying GNN predictor, as described in the
     `"Evaluating Explainability for Graph Neural Networks"
     <https://arxiv.org/abs/2208.09339>`_ paper.
 

torch_geometric/graphgym/config.py

@@ -16,8 +16,9 @@ try:  # Define global config object
     cfg = CN()
 except ImportError:
     cfg = None
-    warnings.warn("Could not define global config object. Please install "
-                  "'yacs' via 'pip install yacs' in order to use GraphGym")
+    warnings.warn(
+        "Could not define global config object. Please install "
+        "'yacs' via 'pip install yacs' in order to use GraphGym", stacklevel=2)
 
 
 def set_cfg(cfg):

torch_geometric/graphgym/imports.py

@@ -3,13 +3,24 @@ import warnings
 import torch
 
 try:
-    import pytorch_lightning as pl
+    import lightning.pytorch as pl
+    _pl_is_available = True
+except ImportError:
+    try:
+        import pytorch_lightning as pl
+        _pl_is_available = True
+    except ImportError:
+        _pl_is_available = False
+
+if _pl_is_available:
     LightningModule = pl.LightningModule
     Callback = pl.Callback
-except ImportError:
+else:
     pl = object
     LightningModule = torch.nn.Module
     Callback = object
 
-    warnings.warn("Please install 'pytorch_lightning' via "
-                  "'pip install pytorch_lightning' in order to use GraphGym")
+    warnings.warn(
+        "To use GraphGym, install 'pytorch_lightning' or 'lightning' via "
+        "'pip install pytorch_lightning' or 'pip install lightning'",
+        stacklevel=2)
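
The hunk above lets GraphGym run against either the newer 'lightning' distribution or the legacy 'pytorch_lightning' one, and only warns when neither is installed. Downstream code that imports Lightning directly can guard itself with the same fallback shape; a minimal sketch, with LitModel as a hypothetical placeholder class:

try:
    import lightning.pytorch as pl  # newer unified 'lightning' package
except ImportError:
    try:
        import pytorch_lightning as pl  # legacy standalone distribution
    except ImportError:
        pl = None  # Lightning is not installed at all

if pl is not None:
    class LitModel(pl.LightningModule):  # hypothetical placeholder module
        pass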

torch_geometric/graphgym/logger.py

@@ -239,7 +239,7 @@ def create_logger():
     r"""Create logger for the experiment."""
     loggers = []
     names = ['train', 'val', 'test']
-    for i, dataset in enumerate(range(cfg.share.num_splits)):
+    for i, _ in enumerate(range(cfg.share.num_splits)):
         loggers.append(Logger(name=names[i], task_type=infer_task()))
     return loggers
 