pyg-nightly 2.7.0.dev20250406__py3-none-any.whl → 2.7.0.dev20250407__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between those package versions.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pyg-nightly
- Version: 2.7.0.dev20250406
+ Version: 2.7.0.dev20250407
  Summary: Graph Neural Network Library for PyTorch
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
- torch_geometric/__init__.py,sha256=EFUlgJy_cHoHOgqO8KCynWIfRJFW8DFqG7O5v9DFOzI,1978
+ torch_geometric/__init__.py,sha256=peFG3sVQB1R7kf4erqOCMV6_UO_k1PvDHdQg1HBaqog,1978
  torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -200,7 +200,7 @@ torch_geometric/explain/algorithm/base.py,sha256=wwJcREUFKDLFUDjRa9o4X3DWqQgMvhS
  torch_geometric/explain/algorithm/captum.py,sha256=k6hNgC5Kn9lVirOYVJzej8-hRuf5C2mPFUXFLd2wWsY,12857
  torch_geometric/explain/algorithm/captum_explainer.py,sha256=oz-c40hvdzii4_chEQPHzQo_dFjHr9HLuJhDLsqRIVU,7346
  torch_geometric/explain/algorithm/dummy_explainer.py,sha256=jvcVQmfngmUWgoKa5p7CXzju2HM5D5DfieJhZW3gbLc,2872
- torch_geometric/explain/algorithm/gnn_explainer.py,sha256=TRGwaKYn9nLn3fp0rSSzeGs9uHj2rZzfomMseDfq8O8,12454
+ torch_geometric/explain/algorithm/gnn_explainer.py,sha256=iu45fGWdd4c6wNczWEAT-29HCAz7ncuoaS6cpx-xDJM,24660
  torch_geometric/explain/algorithm/graphmask_explainer.py,sha256=T2B081dK-JSpaQmutnkQd5xF3JF49_dPZCOgwqIKJDk,21367
  torch_geometric/explain/algorithm/pg_explainer.py,sha256=zPsl0tT9ISSWK1xP1KKpe1ZjUarhSB736WTtqwcmDIo,10372
  torch_geometric/explain/algorithm/utils.py,sha256=eh0ARPG41V7piVw5jdMYpV0p7WjTlpehnY-bWqPV_zg,2564
@@ -636,7 +636,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
  torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
  torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
- pyg_nightly-2.7.0.dev20250406.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
- pyg_nightly-2.7.0.dev20250406.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
- pyg_nightly-2.7.0.dev20250406.dist-info/METADATA,sha256=csAfUo5zCWohsFtqQsynpzuSflZRg7_4f9DB3JWSVWE,63021
- pyg_nightly-2.7.0.dev20250406.dist-info/RECORD,,
+ pyg_nightly-2.7.0.dev20250407.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
+ pyg_nightly-2.7.0.dev20250407.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+ pyg_nightly-2.7.0.dev20250407.dist-info/METADATA,sha256=uPWrrVVadg-GQ1_t4bwTD4LTyNy8r0m9_g5BaKW1AVs,63021
+ pyg_nightly-2.7.0.dev20250407.dist-info/RECORD,,
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
 
- __version__ = '2.7.0.dev20250406'
+ __version__ = '2.7.0.dev20250407'
 
  __all__ = [
      'Index',
@@ -1,14 +1,24 @@
  from math import sqrt
- from typing import Optional, Tuple, Union
+ from typing import Dict, Optional, Tuple, Union, overload
 
  import torch
  from torch import Tensor
  from torch.nn.parameter import Parameter
 
- from torch_geometric.explain import ExplainerConfig, Explanation, ModelConfig
+ from torch_geometric.explain import (
+     ExplainerConfig,
+     Explanation,
+     HeteroExplanation,
+     ModelConfig,
+ )
  from torch_geometric.explain.algorithm import ExplainerAlgorithm
- from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
+ from torch_geometric.explain.algorithm.utils import (
+     clear_masks,
+     set_hetero_masks,
+     set_masks,
+ )
  from torch_geometric.explain.config import MaskType, ModelMode, ModelTaskLevel
+ from torch_geometric.typing import EdgeType, NodeType
 
 
  class GNNExplainer(ExplainerAlgorithm):
@@ -69,7 +79,9 @@ class GNNExplainer(ExplainerAlgorithm):
 
          self.node_mask = self.hard_node_mask = None
          self.edge_mask = self.hard_edge_mask = None
+         self.is_hetero = False
 
+     @overload
      def forward(
          self,
          model: torch.nn.Module,
@@ -80,30 +92,87 @@ class GNNExplainer(ExplainerAlgorithm):
          index: Optional[Union[int, Tensor]] = None,
          **kwargs,
      ) -> Explanation:
-         if isinstance(x, dict) or isinstance(edge_index, dict):
-             raise ValueError(f"Heterogeneous graphs not yet supported in "
-                              f"'{self.__class__.__name__}'")
+         ...
 
-         self._train(model, x, edge_index, target=target, index=index, **kwargs)
-
-         node_mask = self._post_process_mask(
-             self.node_mask,
-             self.hard_node_mask,
-             apply_sigmoid=True,
-         )
-         edge_mask = self._post_process_mask(
-             self.edge_mask,
-             self.hard_edge_mask,
-             apply_sigmoid=True,
-         )
+     @overload
+     def forward(
+         self,
+         model: torch.nn.Module,
+         x: Dict[NodeType, Tensor],
+         edge_index: Dict[EdgeType, Tensor],
+         *,
+         target: Tensor,
+         index: Optional[Union[int, Tensor]] = None,
+         **kwargs,
+     ) -> HeteroExplanation:
+         ...
 
+     def forward(
+         self,
+         model: torch.nn.Module,
+         x: Union[Tensor, Dict[NodeType, Tensor]],
+         edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+         *,
+         target: Tensor,
+         index: Optional[Union[int, Tensor]] = None,
+         **kwargs,
+     ) -> Union[Explanation, HeteroExplanation]:
+         self.is_hetero = isinstance(x, dict)
+         self._train(model, x, edge_index, target=target, index=index, **kwargs)
+         explanation = self._create_explanation()
          self._clean_model(model)
+         return explanation
+
+     def _create_explanation(self) -> Union[Explanation, HeteroExplanation]:
+         """Create an explanation object from the current masks."""
+         if self.is_hetero:
+             # For heterogeneous graphs, process each type separately
+             node_mask_dict = {}
+             edge_mask_dict = {}
+
+             for node_type, mask in self.node_mask.items():
+                 if mask is not None:
+                     node_mask_dict[node_type] = self._post_process_mask(
+                         mask,
+                         self.hard_node_mask[node_type],
+                         apply_sigmoid=True,
+                     )
+
+             for edge_type, mask in self.edge_mask.items():
+                 if mask is not None:
+                     edge_mask_dict[edge_type] = self._post_process_mask(
+                         mask,
+                         self.hard_edge_mask[edge_type],
+                         apply_sigmoid=True,
+                     )
+
+             # Create heterogeneous explanation
+             explanation = HeteroExplanation()
+             explanation.set_value_dict('node_mask', node_mask_dict)
+             explanation.set_value_dict('edge_mask', edge_mask_dict)
 
-         return Explanation(node_mask=node_mask, edge_mask=edge_mask)
+         else:
+             # For homogeneous graphs, process single masks
+             node_mask = self._post_process_mask(
+                 self.node_mask,
+                 self.hard_node_mask,
+                 apply_sigmoid=True,
+             )
+             edge_mask = self._post_process_mask(
+                 self.edge_mask,
+                 self.hard_edge_mask,
+                 apply_sigmoid=True,
+             )
+
+             # Create homogeneous explanation
+             explanation = Explanation(node_mask=node_mask, edge_mask=edge_mask)
+
+         return explanation
 
      def supports(self) -> bool:
          return True
 
+     @overload
      def _train(
          self,
          model: torch.nn.Module,
@@ -113,57 +182,222 @@ class GNNExplainer(ExplainerAlgorithm):
          target: Tensor,
          index: Optional[Union[int, Tensor]] = None,
          **kwargs,
-     ):
+     ) -> None:
+         ...
+
+     @overload
+     def _train(
+         self,
+         model: torch.nn.Module,
+         x: Dict[NodeType, Tensor],
+         edge_index: Dict[EdgeType, Tensor],
+         *,
+         target: Tensor,
+         index: Optional[Union[int, Tensor]] = None,
+         **kwargs,
+     ) -> None:
+         ...
+
+     def _train(
+         self,
+         model: torch.nn.Module,
+         x: Union[Tensor, Dict[NodeType, Tensor]],
+         edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+         *,
+         target: Tensor,
+         index: Optional[Union[int, Tensor]] = None,
+         **kwargs,
+     ) -> None:
+         # Initialize masks based on input type
          self._initialize_masks(x, edge_index)
 
-         parameters = []
-         if self.node_mask is not None:
-             parameters.append(self.node_mask)
-         if self.edge_mask is not None:
-             set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)
-             parameters.append(self.edge_mask)
+         # Collect parameters for optimization
+         parameters = self._collect_parameters(model, edge_index)
 
+         # Create optimizer
          optimizer = torch.optim.Adam(parameters, lr=self.lr)
 
+         # Training loop
          for i in range(self.epochs):
              optimizer.zero_grad()
 
-             h = x if self.node_mask is None else x * self.node_mask.sigmoid()
-             y_hat, y = model(h, edge_index, **kwargs), target
+             # Forward pass with masked inputs
+             y_hat = self._forward_with_masks(model, x, edge_index, **kwargs)
+             y = target
 
+             # Handle index if provided
              if index is not None:
                  y_hat, y = y_hat[index], y[index]
 
+             # Calculate loss
              loss = self._loss(y_hat, y)
 
+             # Backward pass
              loss.backward()
              optimizer.step()
 
-             # In the first iteration, we collect the nodes and edges that are
-             # involved into making the prediction. These are all the nodes and
-             # edges with gradient != 0 (without regularization applied).
-             if i == 0 and self.node_mask is not None:
-                 if self.node_mask.grad is None:
-                     raise ValueError("Could not compute gradients for node "
-                                      "features. Please make sure that node "
-                                      "features are used inside the model or "
-                                      "disable it via `node_mask_type=None`.")
-                 self.hard_node_mask = self.node_mask.grad != 0.0
-             if i == 0 and self.edge_mask is not None:
-                 if self.edge_mask.grad is None:
-                     raise ValueError("Could not compute gradients for edges. "
-                                      "Please make sure that edges are used "
-                                      "via message passing inside the model or "
-                                      "disable it via `edge_mask_type=None`.")
-                 self.hard_edge_mask = self.edge_mask.grad != 0.0
-
-     def _initialize_masks(self, x: Tensor, edge_index: Tensor):
+             # In the first iteration, collect gradients to identify important
+             # nodes/edges
+             if i == 0:
+                 self._collect_gradients()
+
+     def _collect_parameters(self, model, edge_index):
+         """Collect parameters for optimization."""
+         parameters = []
+
+         if self.is_hetero:
+             # For heterogeneous graphs, collect parameters from all types
+             for mask in self.node_mask.values():
+                 if mask is not None:
+                     parameters.append(mask)
+             if any(v is not None for v in self.edge_mask.values()):
+                 set_hetero_masks(model, self.edge_mask, edge_index)
+                 for mask in self.edge_mask.values():
+                     if mask is not None:
+                         parameters.append(mask)
+         else:
+             # For homogeneous graphs, collect single parameters
+             if self.node_mask is not None:
+                 parameters.append(self.node_mask)
+             if self.edge_mask is not None:
+                 set_masks(model, self.edge_mask, edge_index,
+                           apply_sigmoid=True)
+                 parameters.append(self.edge_mask)
+
+         return parameters
+
+     @overload
+     def _forward_with_masks(
+         self,
+         model: torch.nn.Module,
+         x: Tensor,
+         edge_index: Tensor,
+         **kwargs,
+     ) -> Tensor:
+         ...
+
+     @overload
+     def _forward_with_masks(
+         self,
+         model: torch.nn.Module,
+         x: Dict[NodeType, Tensor],
+         edge_index: Dict[EdgeType, Tensor],
+         **kwargs,
+     ) -> Tensor:
+         ...
+
+     def _forward_with_masks(
+         self,
+         model: torch.nn.Module,
+         x: Union[Tensor, Dict[NodeType, Tensor]],
+         edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+         **kwargs,
+     ) -> Tensor:
+         """Forward pass with masked inputs."""
+         if self.is_hetero:
+             # Apply masks to heterogeneous inputs
+             h_dict = {}
+             for node_type, features in x.items():
+                 if node_type in self.node_mask and self.node_mask[
+                         node_type] is not None:
+                     h_dict[node_type] = features * self.node_mask[
+                         node_type].sigmoid()
+                 else:
+                     h_dict[node_type] = features
+
+             # Forward pass with masked features
+             return model(h_dict, edge_index, **kwargs)
+         else:
+             # Apply mask to homogeneous input
+             h = x if self.node_mask is None else x * self.node_mask.sigmoid()
+
+             # Forward pass with masked features
+             return model(h, edge_index, **kwargs)
+
+     def _initialize_masks(
+         self,
+         x: Union[Tensor, Dict[NodeType, Tensor]],
+         edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+     ) -> None:
          node_mask_type = self.explainer_config.node_mask_type
          edge_mask_type = self.explainer_config.edge_mask_type
 
-         device = x.device
-         (N, F), E = x.size(), edge_index.size(1)
+         if self.is_hetero:
+             # Initialize dictionaries for heterogeneous masks
+             self.node_mask = {}
+             self.hard_node_mask = {}
+             self.edge_mask = {}
+             self.hard_edge_mask = {}
+
+             # Initialize node masks for each node type
+             for node_type, features in x.items():
+                 device = features.device
+                 N, F = features.size()
+                 self._initialize_node_mask(node_mask_type, node_type, N, F,
+                                            device)
+
+             # Initialize edge masks for each edge type
+             for edge_type, indices in edge_index.items():
+                 device = indices.device
+                 E = indices.size(1)
+                 N = max(indices.max().item() + 1,
+                         max(feat.size(0) for feat in x.values()))
+                 self._initialize_edge_mask(edge_mask_type, edge_type, E, N,
+                                            device)
+         else:
+             # Initialize masks for homogeneous graph
+             device = x.device
+             (N, F), E = x.size(), edge_index.size(1)
+
+             # Initialize homogeneous node and edge masks
+             self._initialize_homogeneous_masks(node_mask_type, edge_mask_type,
+                                                N, F, E, device)
+
+     def _initialize_node_mask(
+         self,
+         node_mask_type,
+         node_type,
+         N,
+         F,
+         device,
+     ) -> None:
+         """Initialize node mask for a specific node type."""
+         std = 0.1
+         if node_mask_type is None:
+             self.node_mask[node_type] = None
+             self.hard_node_mask[node_type] = None
+         elif node_mask_type == MaskType.object:
+             self.node_mask[node_type] = Parameter(
+                 torch.randn(N, 1, device=device) * std)
+             self.hard_node_mask[node_type] = None
+         elif node_mask_type == MaskType.attributes:
+             self.node_mask[node_type] = Parameter(
+                 torch.randn(N, F, device=device) * std)
+             self.hard_node_mask[node_type] = None
+         elif node_mask_type == MaskType.common_attributes:
+             self.node_mask[node_type] = Parameter(
+                 torch.randn(1, F, device=device) * std)
+             self.hard_node_mask[node_type] = None
+         else:
+             raise ValueError(f"Invalid node mask type: {node_mask_type}")
+
+     def _initialize_edge_mask(self, edge_mask_type, edge_type, E, N, device):
+         """Initialize edge mask for a specific edge type."""
+         if edge_mask_type is None:
+             self.edge_mask[edge_type] = None
+             self.hard_edge_mask[edge_type] = None
+         elif edge_mask_type == MaskType.object:
+             std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
+             self.edge_mask[edge_type] = Parameter(
+                 torch.randn(E, device=device) * std)
+             self.hard_edge_mask[edge_type] = None
+         else:
+             raise ValueError(f"Invalid edge mask type: {edge_mask_type}")
 
+     def _initialize_homogeneous_masks(self, node_mask_type, edge_mask_type, N,
+                                       F, E, device):
+         """Initialize masks for homogeneous graph."""
+         # Initialize node mask
          std = 0.1
          if node_mask_type is None:
              self.node_mask = None
@@ -174,43 +408,145 @@ class GNNExplainer(ExplainerAlgorithm):
          elif node_mask_type == MaskType.common_attributes:
              self.node_mask = Parameter(torch.randn(1, F, device=device) * std)
          else:
-             assert False
+             raise ValueError(f"Invalid node mask type: {node_mask_type}")
 
+         # Initialize edge mask
          if edge_mask_type is None:
              self.edge_mask = None
          elif edge_mask_type == MaskType.object:
              std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
              self.edge_mask = Parameter(torch.randn(E, device=device) * std)
          else:
-             assert False
+             raise ValueError(f"Invalid edge mask type: {edge_mask_type}")
+
+     def _collect_gradients(self) -> None:
+         if self.is_hetero:
+             self._collect_hetero_gradients()
+         else:
+             self._collect_homo_gradients()
+
+     def _collect_hetero_gradients(self):
+         """Collect gradients for heterogeneous graph."""
+         for node_type, mask in self.node_mask.items():
+             if mask is not None:
+                 if mask.grad is None:
+                     raise ValueError(
+                         f"Could not compute gradients for node masks of type "
+                         f"'{node_type}'. Please make sure that node masks are "
+                         f"used inside the model or disable it via "
+                         f"`node_mask_type=None`.")
+
+                 self.hard_node_mask[node_type] = mask.grad != 0.0
+
+         for edge_type, mask in self.edge_mask.items():
+             if mask is not None:
+                 if mask.grad is None:
+                     raise ValueError(
+                         f"Could not compute gradients for edge masks of type "
+                         f"'{edge_type}'. Please make sure that edge masks are "
+                         f"used inside the model or disable it via "
+                         f"`edge_mask_type=None`.")
+                 self.hard_edge_mask[edge_type] = mask.grad != 0.0
+
+     def _collect_homo_gradients(self):
+         """Collect gradients for homogeneous graph."""
+         if self.node_mask is not None:
+             if self.node_mask.grad is None:
+                 raise ValueError("Could not compute gradients for node "
+                                  "features. Please make sure that node "
+                                  "features are used inside the model or "
+                                  "disable it via `node_mask_type=None`.")
+             self.hard_node_mask = self.node_mask.grad != 0.0
+
+         if self.edge_mask is not None:
+             if self.edge_mask.grad is None:
+                 raise ValueError("Could not compute gradients for edges. "
+                                  "Please make sure that edges are used "
+                                  "via message passing inside the model or "
+                                  "disable it via `edge_mask_type=None`.")
+             self.hard_edge_mask = self.edge_mask.grad != 0.0
 
      def _loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
+         # Calculate base loss based on model configuration
+         loss = self._calculate_base_loss(y_hat, y)
+
+         # Apply regularization based on graph type
+         if self.is_hetero:
+             # Apply regularization for heterogeneous graph
+             loss = self._apply_hetero_regularization(loss)
+         else:
+             # Apply regularization for homogeneous graph
+             loss = self._apply_homo_regularization(loss)
+
+         return loss
+
+     def _calculate_base_loss(self, y_hat, y):
+         """Calculate base loss based on model configuration."""
          if self.model_config.mode == ModelMode.binary_classification:
-             loss = self._loss_binary_classification(y_hat, y)
+             return self._loss_binary_classification(y_hat, y)
          elif self.model_config.mode == ModelMode.multiclass_classification:
-             loss = self._loss_multiclass_classification(y_hat, y)
+             return self._loss_multiclass_classification(y_hat, y)
          elif self.model_config.mode == ModelMode.regression:
-             loss = self._loss_regression(y_hat, y)
+             return self._loss_regression(y_hat, y)
          else:
-             assert False
+             raise ValueError(f"Invalid model mode: {self.model_config.mode}")
+
+     def _apply_hetero_regularization(self, loss):
+         """Apply regularization for heterogeneous graph."""
+         # Apply regularization for each edge type
+         for edge_type, mask in self.edge_mask.items():
+             if (mask is not None
+                     and self.hard_edge_mask[edge_type] is not None):
+                 loss = self._add_mask_regularization(
+                     loss, mask, self.hard_edge_mask[edge_type],
+                     self.coeffs['edge_size'], self.coeffs['edge_reduction'],
+                     self.coeffs['edge_ent'])
+
+         # Apply regularization for each node type
+         for node_type, mask in self.node_mask.items():
+             if (mask is not None
+                     and self.hard_node_mask[node_type] is not None):
+                 loss = self._add_mask_regularization(
+                     loss, mask, self.hard_node_mask[node_type],
+                     self.coeffs['node_feat_size'],
+                     self.coeffs['node_feat_reduction'],
+                     self.coeffs['node_feat_ent'])
 
+         return loss
+
+     def _apply_homo_regularization(self, loss):
+         """Apply regularization for homogeneous graph."""
+         # Apply regularization for edge mask
          if self.hard_edge_mask is not None:
              assert self.edge_mask is not None
-             m = self.edge_mask[self.hard_edge_mask].sigmoid()
-             edge_reduce = getattr(torch, self.coeffs['edge_reduction'])
-             loss = loss + self.coeffs['edge_size'] * edge_reduce(m)
-             ent = -m * torch.log(m + self.coeffs['EPS']) - (
-                 1 - m) * torch.log(1 - m + self.coeffs['EPS'])
-             loss = loss + self.coeffs['edge_ent'] * ent.mean()
+             loss = self._add_mask_regularization(loss, self.edge_mask,
+                                                  self.hard_edge_mask,
+                                                  self.coeffs['edge_size'],
+                                                  self.coeffs['edge_reduction'],
+                                                  self.coeffs['edge_ent'])
 
+         # Apply regularization for node mask
          if self.hard_node_mask is not None:
              assert self.node_mask is not None
-             m = self.node_mask[self.hard_node_mask].sigmoid()
-             node_reduce = getattr(torch, self.coeffs['node_feat_reduction'])
-             loss = loss + self.coeffs['node_feat_size'] * node_reduce(m)
-             ent = -m * torch.log(m + self.coeffs['EPS']) - (
-                 1 - m) * torch.log(1 - m + self.coeffs['EPS'])
-             loss = loss + self.coeffs['node_feat_ent'] * ent.mean()
+             loss = self._add_mask_regularization(
+                 loss, self.node_mask, self.hard_node_mask,
+                 self.coeffs['node_feat_size'],
+                 self.coeffs['node_feat_reduction'],
+                 self.coeffs['node_feat_ent'])
+
+         return loss
+
+     def _add_mask_regularization(self, loss, mask, hard_mask, size_coeff,
+                                  reduction_name, ent_coeff):
+         """Add size and entropy regularization for a mask."""
+         m = mask[hard_mask].sigmoid()
+         reduce_fn = getattr(torch, reduction_name)
+         # Add size regularization
+         loss = loss + size_coeff * reduce_fn(m)
+         # Add entropy regularization
+         ent = -m * torch.log(m + self.coeffs['EPS']) - (
+             1 - m) * torch.log(1 - m + self.coeffs['EPS'])
+         loss = loss + ent_coeff * ent.mean()
 
          return loss
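
The gnn_explainer.py changes above add heterogeneous-graph support to GNNExplainer: dict-valued x and edge_index inputs now get per-type node and edge masks and produce a HeteroExplanation instead of raising a ValueError. The sketch below shows how the updated algorithm might be driven through PyTorch Geometric's standard Explainer wrapper. The toy HeteroData graph, the SAGEConv-based GNN and PaperClassifier classes, the 'paper'/'author' types, and all hyperparameters are illustrative assumptions for this note and are not part of the release; treat it as a usage outline under those assumptions, not a reference implementation.

import torch

from torch_geometric.data import HeteroData
from torch_geometric.explain import Explainer, GNNExplainer
from torch_geometric.nn import SAGEConv, to_hetero

# Toy heterogeneous graph (illustrative only): 'paper' and 'author' nodes.
data = HeteroData()
data['paper'].x = torch.randn(8, 16)
data['author'].x = torch.randn(4, 16)
data['author', 'writes', 'paper'].edge_index = torch.tensor(
    [[0, 1, 2, 3], [0, 2, 4, 6]])
data['paper', 'rev_writes', 'author'].edge_index = torch.tensor(
    [[0, 2, 4, 6], [0, 1, 2, 3]])


class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), 32)
        self.conv2 = SAGEConv((-1, -1), 2)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)


class PaperClassifier(torch.nn.Module):
    """Wraps the hetero model so it returns a single tensor of 'paper'
    logits, matching how the new `_train` loop indexes `y_hat[index]`."""
    def __init__(self, metadata):
        super().__init__()
        self.gnn = to_hetero(GNN(), metadata)

    def forward(self, x_dict, edge_index_dict):
        return self.gnn(x_dict, edge_index_dict)['paper']


model = PaperClassifier(data.metadata())

explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=50),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='raw',
    ),
)

# Passing dicts selects the heterogeneous code path, which is expected to
# return a HeteroExplanation holding one node mask per node type and one
# edge mask per edge type.
explanation = explainer(data.x_dict, data.edge_index_dict, index=0)
print(explanation.node_mask_dict)
print(explanation.edge_mask_dict)

Returning the 'paper' logits as a plain tensor keeps target computation and index handling identical to the homogeneous path, which is why the wrapper class is used here; a model that returns a dict of per-type outputs would need its own target handling.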