pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic; see the package's release page on the registry for more details.

Files changed (228)
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +226 -189
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +179 -31
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_trim_to_layer.py +2 -2
  215. torch_geometric/utils/convert.py +17 -10
  216. torch_geometric/utils/cross_entropy.py +34 -13
  217. torch_geometric/utils/embedding.py +91 -2
  218. torch_geometric/utils/geodesic.py +4 -3
  219. torch_geometric/utils/influence.py +279 -0
  220. torch_geometric/utils/map.py +13 -9
  221. torch_geometric/utils/nested.py +1 -1
  222. torch_geometric/utils/smiles.py +3 -3
  223. torch_geometric/utils/sparse.py +7 -14
  224. torch_geometric/visualization/__init__.py +2 -1
  225. torch_geometric/visualization/graph.py +250 -5
  226. torch_geometric/warnings.py +11 -2
  227. torch_geometric/nn/nlp/__init__.py +0 -7
  228. torch_geometric/nn/nlp/sentence_transformer.py +0 -101
@@ -1,14 +1,24 @@
1
1
  from math import sqrt
2
- from typing import Optional, Tuple, Union
2
+ from typing import Dict, Optional, Tuple, Union, overload
3
3
 
4
4
  import torch
5
5
  from torch import Tensor
6
6
  from torch.nn.parameter import Parameter
7
7
 
8
- from torch_geometric.explain import ExplainerConfig, Explanation, ModelConfig
8
+ from torch_geometric.explain import (
9
+ ExplainerConfig,
10
+ Explanation,
11
+ HeteroExplanation,
12
+ ModelConfig,
13
+ )
9
14
  from torch_geometric.explain.algorithm import ExplainerAlgorithm
10
- from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
15
+ from torch_geometric.explain.algorithm.utils import (
16
+ clear_masks,
17
+ set_hetero_masks,
18
+ set_masks,
19
+ )
11
20
  from torch_geometric.explain.config import MaskType, ModelMode, ModelTaskLevel
21
+ from torch_geometric.typing import EdgeType, NodeType
12
22
 
13
23
 
14
24
  class GNNExplainer(ExplainerAlgorithm):
@@ -51,7 +61,7 @@ class GNNExplainer(ExplainerAlgorithm):
51
61
  :attr:`~torch_geometric.explain.algorithm.GNNExplainer.coeffs`.
52
62
  """
53
63
 
54
- coeffs = {
64
+ default_coeffs = {
55
65
  'edge_size': 0.005,
56
66
  'edge_reduction': 'sum',
57
67
  'node_feat_size': 1.0,
@@ -65,11 +75,14 @@ class GNNExplainer(ExplainerAlgorithm):
65
75
  super().__init__()
66
76
  self.epochs = epochs
67
77
  self.lr = lr
78
+ self.coeffs = dict(self.default_coeffs)
68
79
  self.coeffs.update(kwargs)
69
80
 
70
81
  self.node_mask = self.hard_node_mask = None
71
82
  self.edge_mask = self.hard_edge_mask = None
83
+ self.is_hetero = False
72
84
 
85
+ @overload
73
86
  def forward(
74
87
  self,
75
88
  model: torch.nn.Module,
@@ -80,30 +93,87 @@ class GNNExplainer(ExplainerAlgorithm):
80
93
  index: Optional[Union[int, Tensor]] = None,
81
94
  **kwargs,
82
95
  ) -> Explanation:
83
- if isinstance(x, dict) or isinstance(edge_index, dict):
84
- raise ValueError(f"Heterogeneous graphs not yet supported in "
85
- f"'{self.__class__.__name__}'")
96
+ ...
86
97
 
87
- self._train(model, x, edge_index, target=target, index=index, **kwargs)
88
-
89
- node_mask = self._post_process_mask(
90
- self.node_mask,
91
- self.hard_node_mask,
92
- apply_sigmoid=True,
93
- )
94
- edge_mask = self._post_process_mask(
95
- self.edge_mask,
96
- self.hard_edge_mask,
97
- apply_sigmoid=True,
98
- )
98
+ @overload
99
+ def forward(
100
+ self,
101
+ model: torch.nn.Module,
102
+ x: Dict[NodeType, Tensor],
103
+ edge_index: Dict[EdgeType, Tensor],
104
+ *,
105
+ target: Tensor,
106
+ index: Optional[Union[int, Tensor]] = None,
107
+ **kwargs,
108
+ ) -> HeteroExplanation:
109
+ ...
99
110
 
111
+ def forward(
112
+ self,
113
+ model: torch.nn.Module,
114
+ x: Union[Tensor, Dict[NodeType, Tensor]],
115
+ edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
116
+ *,
117
+ target: Tensor,
118
+ index: Optional[Union[int, Tensor]] = None,
119
+ **kwargs,
120
+ ) -> Union[Explanation, HeteroExplanation]:
121
+ self.is_hetero = isinstance(x, dict)
122
+ self._train(model, x, edge_index, target=target, index=index, **kwargs)
123
+ explanation = self._create_explanation()
100
124
  self._clean_model(model)
125
+ return explanation
126
+
127
+ def _create_explanation(self) -> Union[Explanation, HeteroExplanation]:
128
+ """Create an explanation object from the current masks."""
129
+ if self.is_hetero:
130
+ # For heterogeneous graphs, process each type separately
131
+ node_mask_dict = {}
132
+ edge_mask_dict = {}
133
+
134
+ for node_type, mask in self.node_mask.items():
135
+ if mask is not None:
136
+ node_mask_dict[node_type] = self._post_process_mask(
137
+ mask,
138
+ self.hard_node_mask[node_type],
139
+ apply_sigmoid=True,
140
+ )
141
+
142
+ for edge_type, mask in self.edge_mask.items():
143
+ if mask is not None:
144
+ edge_mask_dict[edge_type] = self._post_process_mask(
145
+ mask,
146
+ self.hard_edge_mask[edge_type],
147
+ apply_sigmoid=True,
148
+ )
149
+
150
+ # Create heterogeneous explanation
151
+ explanation = HeteroExplanation()
152
+ explanation.set_value_dict('node_mask', node_mask_dict)
153
+ explanation.set_value_dict('edge_mask', edge_mask_dict)
101
154
 
102
- return Explanation(node_mask=node_mask, edge_mask=edge_mask)
155
+ else:
156
+ # For homogeneous graphs, process single masks
157
+ node_mask = self._post_process_mask(
158
+ self.node_mask,
159
+ self.hard_node_mask,
160
+ apply_sigmoid=True,
161
+ )
162
+ edge_mask = self._post_process_mask(
163
+ self.edge_mask,
164
+ self.hard_edge_mask,
165
+ apply_sigmoid=True,
166
+ )
167
+
168
+ # Create homogeneous explanation
169
+ explanation = Explanation(node_mask=node_mask, edge_mask=edge_mask)
170
+
171
+ return explanation
103
172
 
104
173
  def supports(self) -> bool:
105
174
  return True
106
175
 
176
+ @overload
107
177
  def _train(
108
178
  self,
109
179
  model: torch.nn.Module,
@@ -113,57 +183,222 @@ class GNNExplainer(ExplainerAlgorithm):
113
183
  target: Tensor,
114
184
  index: Optional[Union[int, Tensor]] = None,
115
185
  **kwargs,
116
- ):
186
+ ) -> None:
187
+ ...
188
+
189
+ @overload
190
+ def _train(
191
+ self,
192
+ model: torch.nn.Module,
193
+ x: Dict[NodeType, Tensor],
194
+ edge_index: Dict[EdgeType, Tensor],
195
+ *,
196
+ target: Tensor,
197
+ index: Optional[Union[int, Tensor]] = None,
198
+ **kwargs,
199
+ ) -> None:
200
+ ...
201
+
202
+ def _train(
203
+ self,
204
+ model: torch.nn.Module,
205
+ x: Union[Tensor, Dict[NodeType, Tensor]],
206
+ edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
207
+ *,
208
+ target: Tensor,
209
+ index: Optional[Union[int, Tensor]] = None,
210
+ **kwargs,
211
+ ) -> None:
212
+ # Initialize masks based on input type
117
213
  self._initialize_masks(x, edge_index)
118
214
 
119
- parameters = []
120
- if self.node_mask is not None:
121
- parameters.append(self.node_mask)
122
- if self.edge_mask is not None:
123
- set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)
124
- parameters.append(self.edge_mask)
215
+ # Collect parameters for optimization
216
+ parameters = self._collect_parameters(model, edge_index)
125
217
 
218
+ # Create optimizer
126
219
  optimizer = torch.optim.Adam(parameters, lr=self.lr)
127
220
 
221
+ # Training loop
128
222
  for i in range(self.epochs):
129
223
  optimizer.zero_grad()
130
224
 
131
- h = x if self.node_mask is None else x * self.node_mask.sigmoid()
132
- y_hat, y = model(h, edge_index, **kwargs), target
225
+ # Forward pass with masked inputs
226
+ y_hat = self._forward_with_masks(model, x, edge_index, **kwargs)
227
+ y = target
133
228
 
229
+ # Handle index if provided
134
230
  if index is not None:
135
231
  y_hat, y = y_hat[index], y[index]
136
232
 
233
+ # Calculate loss
137
234
  loss = self._loss(y_hat, y)
138
235
 
236
+ # Backward pass
139
237
  loss.backward()
140
238
  optimizer.step()
141
239
 
142
- # In the first iteration, we collect the nodes and edges that are
143
- # involved into making the prediction. These are all the nodes and
144
- # edges with gradient != 0 (without regularization applied).
145
- if i == 0 and self.node_mask is not None:
146
- if self.node_mask.grad is None:
147
- raise ValueError("Could not compute gradients for node "
148
- "features. Please make sure that node "
149
- "features are used inside the model or "
150
- "disable it via `node_mask_type=None`.")
151
- self.hard_node_mask = self.node_mask.grad != 0.0
152
- if i == 0 and self.edge_mask is not None:
153
- if self.edge_mask.grad is None:
154
- raise ValueError("Could not compute gradients for edges. "
155
- "Please make sure that edges are used "
156
- "via message passing inside the model or "
157
- "disable it via `edge_mask_type=None`.")
158
- self.hard_edge_mask = self.edge_mask.grad != 0.0
159
-
160
- def _initialize_masks(self, x: Tensor, edge_index: Tensor):
240
+ # In the first iteration, collect gradients to identify important
241
+ # nodes/edges
242
+ if i == 0:
243
+ self._collect_gradients()
244
+
245
+ def _collect_parameters(self, model, edge_index):
246
+ """Collect parameters for optimization."""
247
+ parameters = []
248
+
249
+ if self.is_hetero:
250
+ # For heterogeneous graphs, collect parameters from all types
251
+ for mask in self.node_mask.values():
252
+ if mask is not None:
253
+ parameters.append(mask)
254
+ if any(v is not None for v in self.edge_mask.values()):
255
+ set_hetero_masks(model, self.edge_mask, edge_index)
256
+ for mask in self.edge_mask.values():
257
+ if mask is not None:
258
+ parameters.append(mask)
259
+ else:
260
+ # For homogeneous graphs, collect single parameters
261
+ if self.node_mask is not None:
262
+ parameters.append(self.node_mask)
263
+ if self.edge_mask is not None:
264
+ set_masks(model, self.edge_mask, edge_index,
265
+ apply_sigmoid=True)
266
+ parameters.append(self.edge_mask)
267
+
268
+ return parameters
269
+
270
+ @overload
271
+ def _forward_with_masks(
272
+ self,
273
+ model: torch.nn.Module,
274
+ x: Tensor,
275
+ edge_index: Tensor,
276
+ **kwargs,
277
+ ) -> Tensor:
278
+ ...
279
+
280
+ @overload
281
+ def _forward_with_masks(
282
+ self,
283
+ model: torch.nn.Module,
284
+ x: Dict[NodeType, Tensor],
285
+ edge_index: Dict[EdgeType, Tensor],
286
+ **kwargs,
287
+ ) -> Tensor:
288
+ ...
289
+
290
+ def _forward_with_masks(
291
+ self,
292
+ model: torch.nn.Module,
293
+ x: Union[Tensor, Dict[NodeType, Tensor]],
294
+ edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
295
+ **kwargs,
296
+ ) -> Tensor:
297
+ """Forward pass with masked inputs."""
298
+ if self.is_hetero:
299
+ # Apply masks to heterogeneous inputs
300
+ h_dict = {}
301
+ for node_type, features in x.items():
302
+ if node_type in self.node_mask and self.node_mask[
303
+ node_type] is not None:
304
+ h_dict[node_type] = features * self.node_mask[
305
+ node_type].sigmoid()
306
+ else:
307
+ h_dict[node_type] = features
308
+
309
+ # Forward pass with masked features
310
+ return model(h_dict, edge_index, **kwargs)
311
+ else:
312
+ # Apply mask to homogeneous input
313
+ h = x if self.node_mask is None else x * self.node_mask.sigmoid()
314
+
315
+ # Forward pass with masked features
316
+ return model(h, edge_index, **kwargs)
317
+
318
+ def _initialize_masks(
319
+ self,
320
+ x: Union[Tensor, Dict[NodeType, Tensor]],
321
+ edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
322
+ ) -> None:
161
323
  node_mask_type = self.explainer_config.node_mask_type
162
324
  edge_mask_type = self.explainer_config.edge_mask_type
163
325
 
164
- device = x.device
165
- (N, F), E = x.size(), edge_index.size(1)
326
+ if self.is_hetero:
327
+ # Initialize dictionaries for heterogeneous masks
328
+ self.node_mask = {}
329
+ self.hard_node_mask = {}
330
+ self.edge_mask = {}
331
+ self.hard_edge_mask = {}
332
+
333
+ # Initialize node masks for each node type
334
+ for node_type, features in x.items():
335
+ device = features.device
336
+ N, F = features.size()
337
+ self._initialize_node_mask(node_mask_type, node_type, N, F,
338
+ device)
339
+
340
+ # Initialize edge masks for each edge type
341
+ for edge_type, indices in edge_index.items():
342
+ device = indices.device
343
+ E = indices.size(1)
344
+ N = max(indices.max().item() + 1,
345
+ max(feat.size(0) for feat in x.values()))
346
+ self._initialize_edge_mask(edge_mask_type, edge_type, E, N,
347
+ device)
348
+ else:
349
+ # Initialize masks for homogeneous graph
350
+ device = x.device
351
+ (N, F), E = x.size(), edge_index.size(1)
352
+
353
+ # Initialize homogeneous node and edge masks
354
+ self._initialize_homogeneous_masks(node_mask_type, edge_mask_type,
355
+ N, F, E, device)
356
+
357
+ def _initialize_node_mask(
358
+ self,
359
+ node_mask_type,
360
+ node_type,
361
+ N,
362
+ F,
363
+ device,
364
+ ) -> None:
365
+ """Initialize node mask for a specific node type."""
366
+ std = 0.1
367
+ if node_mask_type is None:
368
+ self.node_mask[node_type] = None
369
+ self.hard_node_mask[node_type] = None
370
+ elif node_mask_type == MaskType.object:
371
+ self.node_mask[node_type] = Parameter(
372
+ torch.randn(N, 1, device=device) * std)
373
+ self.hard_node_mask[node_type] = None
374
+ elif node_mask_type == MaskType.attributes:
375
+ self.node_mask[node_type] = Parameter(
376
+ torch.randn(N, F, device=device) * std)
377
+ self.hard_node_mask[node_type] = None
378
+ elif node_mask_type == MaskType.common_attributes:
379
+ self.node_mask[node_type] = Parameter(
380
+ torch.randn(1, F, device=device) * std)
381
+ self.hard_node_mask[node_type] = None
382
+ else:
383
+ raise ValueError(f"Invalid node mask type: {node_mask_type}")
384
+
385
+ def _initialize_edge_mask(self, edge_mask_type, edge_type, E, N, device):
386
+ """Initialize edge mask for a specific edge type."""
387
+ if edge_mask_type is None:
388
+ self.edge_mask[edge_type] = None
389
+ self.hard_edge_mask[edge_type] = None
390
+ elif edge_mask_type == MaskType.object:
391
+ std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
392
+ self.edge_mask[edge_type] = Parameter(
393
+ torch.randn(E, device=device) * std)
394
+ self.hard_edge_mask[edge_type] = None
395
+ else:
396
+ raise ValueError(f"Invalid edge mask type: {edge_mask_type}")
166
397
 
398
+ def _initialize_homogeneous_masks(self, node_mask_type, edge_mask_type, N,
399
+ F, E, device):
400
+ """Initialize masks for homogeneous graph."""
401
+ # Initialize node mask
167
402
  std = 0.1
168
403
  if node_mask_type is None:
169
404
  self.node_mask = None
@@ -174,43 +409,145 @@ class GNNExplainer(ExplainerAlgorithm):
174
409
  elif node_mask_type == MaskType.common_attributes:
175
410
  self.node_mask = Parameter(torch.randn(1, F, device=device) * std)
176
411
  else:
177
- assert False
412
+ raise ValueError(f"Invalid node mask type: {node_mask_type}")
178
413
 
414
+ # Initialize edge mask
179
415
  if edge_mask_type is None:
180
416
  self.edge_mask = None
181
417
  elif edge_mask_type == MaskType.object:
182
418
  std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
183
419
  self.edge_mask = Parameter(torch.randn(E, device=device) * std)
184
420
  else:
185
- assert False
421
+ raise ValueError(f"Invalid edge mask type: {edge_mask_type}")
422
+
423
+ def _collect_gradients(self) -> None:
424
+ if self.is_hetero:
425
+ self._collect_hetero_gradients()
426
+ else:
427
+ self._collect_homo_gradients()
428
+
429
+ def _collect_hetero_gradients(self):
430
+ """Collect gradients for heterogeneous graph."""
431
+ for node_type, mask in self.node_mask.items():
432
+ if mask is not None:
433
+ if mask.grad is None:
434
+ raise ValueError(
435
+ f"Could not compute gradients for node masks of type "
436
+ f"'{node_type}'. Please make sure that node masks are "
437
+ f"used inside the model or disable it via "
438
+ f"`node_mask_type=None`.")
439
+
440
+ self.hard_node_mask[node_type] = mask.grad != 0.0
441
+
442
+ for edge_type, mask in self.edge_mask.items():
443
+ if mask is not None:
444
+ if mask.grad is None:
445
+ raise ValueError(
446
+ f"Could not compute gradients for edge masks of type "
447
+ f"'{edge_type}'. Please make sure that edge masks are "
448
+ f"used inside the model or disable it via "
449
+ f"`edge_mask_type=None`.")
450
+ self.hard_edge_mask[edge_type] = mask.grad != 0.0
451
+
452
+ def _collect_homo_gradients(self):
453
+ """Collect gradients for homogeneous graph."""
454
+ if self.node_mask is not None:
455
+ if self.node_mask.grad is None:
456
+ raise ValueError("Could not compute gradients for node "
457
+ "features. Please make sure that node "
458
+ "features are used inside the model or "
459
+ "disable it via `node_mask_type=None`.")
460
+ self.hard_node_mask = self.node_mask.grad != 0.0
461
+
462
+ if self.edge_mask is not None:
463
+ if self.edge_mask.grad is None:
464
+ raise ValueError("Could not compute gradients for edges. "
465
+ "Please make sure that edges are used "
466
+ "via message passing inside the model or "
467
+ "disable it via `edge_mask_type=None`.")
468
+ self.hard_edge_mask = self.edge_mask.grad != 0.0
186
469
 
187
470
  def _loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
471
+ # Calculate base loss based on model configuration
472
+ loss = self._calculate_base_loss(y_hat, y)
473
+
474
+ # Apply regularization based on graph type
475
+ if self.is_hetero:
476
+ # Apply regularization for heterogeneous graph
477
+ loss = self._apply_hetero_regularization(loss)
478
+ else:
479
+ # Apply regularization for homogeneous graph
480
+ loss = self._apply_homo_regularization(loss)
481
+
482
+ return loss
483
+
484
+ def _calculate_base_loss(self, y_hat, y):
485
+ """Calculate base loss based on model configuration."""
188
486
  if self.model_config.mode == ModelMode.binary_classification:
189
- loss = self._loss_binary_classification(y_hat, y)
487
+ return self._loss_binary_classification(y_hat, y)
190
488
  elif self.model_config.mode == ModelMode.multiclass_classification:
191
- loss = self._loss_multiclass_classification(y_hat, y)
489
+ return self._loss_multiclass_classification(y_hat, y)
192
490
  elif self.model_config.mode == ModelMode.regression:
193
- loss = self._loss_regression(y_hat, y)
491
+ return self._loss_regression(y_hat, y)
194
492
  else:
195
- assert False
493
+ raise ValueError(f"Invalid model mode: {self.model_config.mode}")
494
+
495
+ def _apply_hetero_regularization(self, loss):
496
+ """Apply regularization for heterogeneous graph."""
497
+ # Apply regularization for each edge type
498
+ for edge_type, mask in self.edge_mask.items():
499
+ if (mask is not None
500
+ and self.hard_edge_mask[edge_type] is not None):
501
+ loss = self._add_mask_regularization(
502
+ loss, mask, self.hard_edge_mask[edge_type],
503
+ self.coeffs['edge_size'], self.coeffs['edge_reduction'],
504
+ self.coeffs['edge_ent'])
505
+
506
+ # Apply regularization for each node type
507
+ for node_type, mask in self.node_mask.items():
508
+ if (mask is not None
509
+ and self.hard_node_mask[node_type] is not None):
510
+ loss = self._add_mask_regularization(
511
+ loss, mask, self.hard_node_mask[node_type],
512
+ self.coeffs['node_feat_size'],
513
+ self.coeffs['node_feat_reduction'],
514
+ self.coeffs['node_feat_ent'])
196
515
 
516
+ return loss
517
+
518
+ def _apply_homo_regularization(self, loss):
519
+ """Apply regularization for homogeneous graph."""
520
+ # Apply regularization for edge mask
197
521
  if self.hard_edge_mask is not None:
198
522
  assert self.edge_mask is not None
199
- m = self.edge_mask[self.hard_edge_mask].sigmoid()
200
- edge_reduce = getattr(torch, self.coeffs['edge_reduction'])
201
- loss = loss + self.coeffs['edge_size'] * edge_reduce(m)
202
- ent = -m * torch.log(m + self.coeffs['EPS']) - (
203
- 1 - m) * torch.log(1 - m + self.coeffs['EPS'])
204
- loss = loss + self.coeffs['edge_ent'] * ent.mean()
523
+ loss = self._add_mask_regularization(loss, self.edge_mask,
524
+ self.hard_edge_mask,
525
+ self.coeffs['edge_size'],
526
+ self.coeffs['edge_reduction'],
527
+ self.coeffs['edge_ent'])
205
528
 
529
+ # Apply regularization for node mask
206
530
  if self.hard_node_mask is not None:
207
531
  assert self.node_mask is not None
208
- m = self.node_mask[self.hard_node_mask].sigmoid()
209
- node_reduce = getattr(torch, self.coeffs['node_feat_reduction'])
210
- loss = loss + self.coeffs['node_feat_size'] * node_reduce(m)
211
- ent = -m * torch.log(m + self.coeffs['EPS']) - (
212
- 1 - m) * torch.log(1 - m + self.coeffs['EPS'])
213
- loss = loss + self.coeffs['node_feat_ent'] * ent.mean()
532
+ loss = self._add_mask_regularization(
533
+ loss, self.node_mask, self.hard_node_mask,
534
+ self.coeffs['node_feat_size'],
535
+ self.coeffs['node_feat_reduction'],
536
+ self.coeffs['node_feat_ent'])
537
+
538
+ return loss
539
+
540
+ def _add_mask_regularization(self, loss, mask, hard_mask, size_coeff,
541
+ reduction_name, ent_coeff):
542
+ """Add size and entropy regularization for a mask."""
543
+ m = mask[hard_mask].sigmoid()
544
+ reduce_fn = getattr(torch, reduction_name)
545
+ # Add size regularization
546
+ loss = loss + size_coeff * reduce_fn(m)
547
+ # Add entropy regularization
548
+ ent = -m * torch.log(m + self.coeffs['EPS']) - (
549
+ 1 - m) * torch.log(1 - m + self.coeffs['EPS'])
550
+ loss = loss + ent_coeff * ent.mean()
214
551
 
215
552
  return loss
216
553
 
@@ -223,7 +560,7 @@ class GNNExplainer(ExplainerAlgorithm):
223
560
  class GNNExplainer_:
224
561
  r"""Deprecated version for :class:`GNNExplainer`."""
225
562
 
226
- coeffs = GNNExplainer.coeffs
563
+ coeffs = GNNExplainer.default_coeffs
227
564
 
228
565
  conversion_node_mask_type = {
229
566
  'feature': 'common_attributes',
@@ -202,25 +202,25 @@ class GraphMaskExplainer(ExplainerAlgorithm):
202
202
 
203
203
  baselines, self.gates, full_biases = [], torch.nn.ModuleList(), []
204
204
 
205
- for v_dim, m_dim, h_dim in zip(i_dim, j_dim, h_dim):
205
+ for v_dim, m_dim, o_dim in zip(i_dim, j_dim, h_dim):
206
206
  self.transform, self.layer_norm = [], []
207
207
  input_dims = [v_dim, m_dim, v_dim]
208
208
  for _, input_dim in enumerate(input_dims):
209
209
  self.transform.append(
210
- Linear(input_dim, h_dim, bias=False).to(device))
211
- self.layer_norm.append(LayerNorm(h_dim).to(device))
210
+ Linear(input_dim, o_dim, bias=False).to(device))
211
+ self.layer_norm.append(LayerNorm(o_dim).to(device))
212
212
 
213
213
  self.transforms = torch.nn.ModuleList(self.transform)
214
214
  self.layer_norms = torch.nn.ModuleList(self.layer_norm)
215
215
 
216
216
  self.full_bias = Parameter(
217
- torch.tensor(h_dim, dtype=torch.float, device=device))
217
+ torch.tensor(o_dim, dtype=torch.float, device=device))
218
218
  full_biases.append(self.full_bias)
219
219
 
220
- self.reset_parameters(input_dims, h_dim)
220
+ self.reset_parameters(input_dims, o_dim)
221
221
 
222
222
  self.non_linear = ReLU()
223
- self.output_layer = Linear(h_dim, 1).to(device)
223
+ self.output_layer = Linear(o_dim, 1).to(device)
224
224
 
225
225
  gate = [
226
226
  self.transforms, self.layer_norms, self.non_linear,
@@ -274,7 +274,7 @@ class GraphMaskExplainer(ExplainerAlgorithm):
274
274
  elif self.model_config.mode == ModelMode.regression:
275
275
  loss = self._loss_regression(y_hat, y)
276
276
  else:
277
- assert False
277
+ raise AssertionError()
278
278
 
279
279
  g = torch.relu(loss - self.allowance).mean()
280
280
  f = penalty * self.penalty_scaling
@@ -385,7 +385,7 @@ class GraphMaskExplainer(ExplainerAlgorithm):
385
385
  f'Train explainer for graph {index} with layer '
386
386
  f'{layer}')
387
387
  self._enable_layer(layer)
388
- for epoch in range(self.epochs):
388
+ for _ in range(self.epochs):
389
389
  with torch.no_grad():
390
390
  model(x, edge_index, **kwargs)
391
391
  gates, total_penalty = [], 0