pyg-nightly 2.7.0.dev20250405__py3-none-any.whl → 2.7.0.dev20250407__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250405.dist-info → pyg_nightly-2.7.0.dev20250407.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250405.dist-info → pyg_nightly-2.7.0.dev20250407.dist-info}/RECORD +6 -6
- torch_geometric/__init__.py +1 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +403 -67
- {pyg_nightly-2.7.0.dev20250405.dist-info → pyg_nightly-2.7.0.dev20250407.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250405.dist-info → pyg_nightly-2.7.0.dev20250407.dist-info}/licenses/LICENSE +0 -0
{pyg_nightly-2.7.0.dev20250405.dist-info → pyg_nightly-2.7.0.dev20250407.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: pyg-nightly
|
3
|
-
Version: 2.7.0.
|
3
|
+
Version: 2.7.0.dev20250407
|
4
4
|
Summary: Graph Neural Network Library for PyTorch
|
5
5
|
Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
|
6
6
|
Author-email: Matthias Fey <matthias@pyg.org>
|
@@ -1,4 +1,4 @@
|
|
1
|
-
torch_geometric/__init__.py,sha256=
|
1
|
+
torch_geometric/__init__.py,sha256=peFG3sVQB1R7kf4erqOCMV6_UO_k1PvDHdQg1HBaqog,1978
|
2
2
|
torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
|
3
3
|
torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
|
4
4
|
torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
|
@@ -200,7 +200,7 @@ torch_geometric/explain/algorithm/base.py,sha256=wwJcREUFKDLFUDjRa9o4X3DWqQgMvhS
|
|
200
200
|
torch_geometric/explain/algorithm/captum.py,sha256=k6hNgC5Kn9lVirOYVJzej8-hRuf5C2mPFUXFLd2wWsY,12857
|
201
201
|
torch_geometric/explain/algorithm/captum_explainer.py,sha256=oz-c40hvdzii4_chEQPHzQo_dFjHr9HLuJhDLsqRIVU,7346
|
202
202
|
torch_geometric/explain/algorithm/dummy_explainer.py,sha256=jvcVQmfngmUWgoKa5p7CXzju2HM5D5DfieJhZW3gbLc,2872
|
203
|
-
torch_geometric/explain/algorithm/gnn_explainer.py,sha256=
|
203
|
+
torch_geometric/explain/algorithm/gnn_explainer.py,sha256=iu45fGWdd4c6wNczWEAT-29HCAz7ncuoaS6cpx-xDJM,24660
|
204
204
|
torch_geometric/explain/algorithm/graphmask_explainer.py,sha256=T2B081dK-JSpaQmutnkQd5xF3JF49_dPZCOgwqIKJDk,21367
|
205
205
|
torch_geometric/explain/algorithm/pg_explainer.py,sha256=zPsl0tT9ISSWK1xP1KKpe1ZjUarhSB736WTtqwcmDIo,10372
|
206
206
|
torch_geometric/explain/algorithm/utils.py,sha256=eh0ARPG41V7piVw5jdMYpV0p7WjTlpehnY-bWqPV_zg,2564
|
@@ -636,7 +636,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
|
|
636
636
|
torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
|
637
637
|
torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
|
638
638
|
torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
|
639
|
-
pyg_nightly-2.7.0.
|
640
|
-
pyg_nightly-2.7.0.
|
641
|
-
pyg_nightly-2.7.0.
|
642
|
-
pyg_nightly-2.7.0.
|
639
|
+
pyg_nightly-2.7.0.dev20250407.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
|
640
|
+
pyg_nightly-2.7.0.dev20250407.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
|
641
|
+
pyg_nightly-2.7.0.dev20250407.dist-info/METADATA,sha256=uPWrrVVadg-GQ1_t4bwTD4LTyNy8r0m9_g5BaKW1AVs,63021
|
642
|
+
pyg_nightly-2.7.0.dev20250407.dist-info/RECORD,,
|
torch_geometric/__init__.py
CHANGED
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
|
|
31
31
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
32
32
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
33
33
|
|
34
|
-
__version__ = '2.7.0.
|
34
|
+
__version__ = '2.7.0.dev20250407'
|
35
35
|
|
36
36
|
__all__ = [
|
37
37
|
'Index',
|
@@ -1,14 +1,24 @@
|
|
1
1
|
from math import sqrt
|
2
|
-
from typing import Optional, Tuple, Union
|
2
|
+
from typing import Dict, Optional, Tuple, Union, overload
|
3
3
|
|
4
4
|
import torch
|
5
5
|
from torch import Tensor
|
6
6
|
from torch.nn.parameter import Parameter
|
7
7
|
|
8
|
-
from torch_geometric.explain import
|
8
|
+
from torch_geometric.explain import (
|
9
|
+
ExplainerConfig,
|
10
|
+
Explanation,
|
11
|
+
HeteroExplanation,
|
12
|
+
ModelConfig,
|
13
|
+
)
|
9
14
|
from torch_geometric.explain.algorithm import ExplainerAlgorithm
|
10
|
-
from torch_geometric.explain.algorithm.utils import
|
15
|
+
from torch_geometric.explain.algorithm.utils import (
|
16
|
+
clear_masks,
|
17
|
+
set_hetero_masks,
|
18
|
+
set_masks,
|
19
|
+
)
|
11
20
|
from torch_geometric.explain.config import MaskType, ModelMode, ModelTaskLevel
|
21
|
+
from torch_geometric.typing import EdgeType, NodeType
|
12
22
|
|
13
23
|
|
14
24
|
class GNNExplainer(ExplainerAlgorithm):
|
@@ -69,7 +79,9 @@ class GNNExplainer(ExplainerAlgorithm):
|
|
69
79
|
|
70
80
|
self.node_mask = self.hard_node_mask = None
|
71
81
|
self.edge_mask = self.hard_edge_mask = None
|
82
|
+
self.is_hetero = False
|
72
83
|
|
84
|
+
@overload
|
73
85
|
def forward(
|
74
86
|
self,
|
75
87
|
model: torch.nn.Module,
|
@@ -80,30 +92,87 @@ class GNNExplainer(ExplainerAlgorithm):
|
|
80
92
|
index: Optional[Union[int, Tensor]] = None,
|
81
93
|
**kwargs,
|
82
94
|
) -> Explanation:
|
83
|
-
|
84
|
-
raise ValueError(f"Heterogeneous graphs not yet supported in "
|
85
|
-
f"'{self.__class__.__name__}'")
|
95
|
+
...
|
86
96
|
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
97
|
+
@overload
|
98
|
+
def forward(
|
99
|
+
self,
|
100
|
+
model: torch.nn.Module,
|
101
|
+
x: Dict[NodeType, Tensor],
|
102
|
+
edge_index: Dict[EdgeType, Tensor],
|
103
|
+
*,
|
104
|
+
target: Tensor,
|
105
|
+
index: Optional[Union[int, Tensor]] = None,
|
106
|
+
**kwargs,
|
107
|
+
) -> HeteroExplanation:
|
108
|
+
...
|
99
109
|
|
110
|
+
def forward(
|
111
|
+
self,
|
112
|
+
model: torch.nn.Module,
|
113
|
+
x: Union[Tensor, Dict[NodeType, Tensor]],
|
114
|
+
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
|
115
|
+
*,
|
116
|
+
target: Tensor,
|
117
|
+
index: Optional[Union[int, Tensor]] = None,
|
118
|
+
**kwargs,
|
119
|
+
) -> Union[Explanation, HeteroExplanation]:
|
120
|
+
self.is_hetero = isinstance(x, dict)
|
121
|
+
self._train(model, x, edge_index, target=target, index=index, **kwargs)
|
122
|
+
explanation = self._create_explanation()
|
100
123
|
self._clean_model(model)
|
124
|
+
return explanation
|
125
|
+
|
126
|
+
def _create_explanation(self) -> Union[Explanation, HeteroExplanation]:
|
127
|
+
"""Create an explanation object from the current masks."""
|
128
|
+
if self.is_hetero:
|
129
|
+
# For heterogeneous graphs, process each type separately
|
130
|
+
node_mask_dict = {}
|
131
|
+
edge_mask_dict = {}
|
132
|
+
|
133
|
+
for node_type, mask in self.node_mask.items():
|
134
|
+
if mask is not None:
|
135
|
+
node_mask_dict[node_type] = self._post_process_mask(
|
136
|
+
mask,
|
137
|
+
self.hard_node_mask[node_type],
|
138
|
+
apply_sigmoid=True,
|
139
|
+
)
|
140
|
+
|
141
|
+
for edge_type, mask in self.edge_mask.items():
|
142
|
+
if mask is not None:
|
143
|
+
edge_mask_dict[edge_type] = self._post_process_mask(
|
144
|
+
mask,
|
145
|
+
self.hard_edge_mask[edge_type],
|
146
|
+
apply_sigmoid=True,
|
147
|
+
)
|
148
|
+
|
149
|
+
# Create heterogeneous explanation
|
150
|
+
explanation = HeteroExplanation()
|
151
|
+
explanation.set_value_dict('node_mask', node_mask_dict)
|
152
|
+
explanation.set_value_dict('edge_mask', edge_mask_dict)
|
101
153
|
|
102
|
-
|
154
|
+
else:
|
155
|
+
# For homogeneous graphs, process single masks
|
156
|
+
node_mask = self._post_process_mask(
|
157
|
+
self.node_mask,
|
158
|
+
self.hard_node_mask,
|
159
|
+
apply_sigmoid=True,
|
160
|
+
)
|
161
|
+
edge_mask = self._post_process_mask(
|
162
|
+
self.edge_mask,
|
163
|
+
self.hard_edge_mask,
|
164
|
+
apply_sigmoid=True,
|
165
|
+
)
|
166
|
+
|
167
|
+
# Create homogeneous explanation
|
168
|
+
explanation = Explanation(node_mask=node_mask, edge_mask=edge_mask)
|
169
|
+
|
170
|
+
return explanation
|
103
171
|
|
104
172
|
def supports(self) -> bool:
|
105
173
|
return True
|
106
174
|
|
175
|
+
@overload
|
107
176
|
def _train(
|
108
177
|
self,
|
109
178
|
model: torch.nn.Module,
|
@@ -113,57 +182,222 @@ class GNNExplainer(ExplainerAlgorithm):
|
|
113
182
|
target: Tensor,
|
114
183
|
index: Optional[Union[int, Tensor]] = None,
|
115
184
|
**kwargs,
|
116
|
-
):
|
185
|
+
) -> None:
|
186
|
+
...
|
187
|
+
|
188
|
+
@overload
|
189
|
+
def _train(
|
190
|
+
self,
|
191
|
+
model: torch.nn.Module,
|
192
|
+
x: Dict[NodeType, Tensor],
|
193
|
+
edge_index: Dict[EdgeType, Tensor],
|
194
|
+
*,
|
195
|
+
target: Tensor,
|
196
|
+
index: Optional[Union[int, Tensor]] = None,
|
197
|
+
**kwargs,
|
198
|
+
) -> None:
|
199
|
+
...
|
200
|
+
|
201
|
+
def _train(
|
202
|
+
self,
|
203
|
+
model: torch.nn.Module,
|
204
|
+
x: Union[Tensor, Dict[NodeType, Tensor]],
|
205
|
+
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
|
206
|
+
*,
|
207
|
+
target: Tensor,
|
208
|
+
index: Optional[Union[int, Tensor]] = None,
|
209
|
+
**kwargs,
|
210
|
+
) -> None:
|
211
|
+
# Initialize masks based on input type
|
117
212
|
self._initialize_masks(x, edge_index)
|
118
213
|
|
119
|
-
parameters
|
120
|
-
|
121
|
-
parameters.append(self.node_mask)
|
122
|
-
if self.edge_mask is not None:
|
123
|
-
set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)
|
124
|
-
parameters.append(self.edge_mask)
|
214
|
+
# Collect parameters for optimization
|
215
|
+
parameters = self._collect_parameters(model, edge_index)
|
125
216
|
|
217
|
+
# Create optimizer
|
126
218
|
optimizer = torch.optim.Adam(parameters, lr=self.lr)
|
127
219
|
|
220
|
+
# Training loop
|
128
221
|
for i in range(self.epochs):
|
129
222
|
optimizer.zero_grad()
|
130
223
|
|
131
|
-
|
132
|
-
y_hat
|
224
|
+
# Forward pass with masked inputs
|
225
|
+
y_hat = self._forward_with_masks(model, x, edge_index, **kwargs)
|
226
|
+
y = target
|
133
227
|
|
228
|
+
# Handle index if provided
|
134
229
|
if index is not None:
|
135
230
|
y_hat, y = y_hat[index], y[index]
|
136
231
|
|
232
|
+
# Calculate loss
|
137
233
|
loss = self._loss(y_hat, y)
|
138
234
|
|
235
|
+
# Backward pass
|
139
236
|
loss.backward()
|
140
237
|
optimizer.step()
|
141
238
|
|
142
|
-
# In the first iteration,
|
143
|
-
#
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
239
|
+
# In the first iteration, collect gradients to identify important
|
240
|
+
# nodes/edges
|
241
|
+
if i == 0:
|
242
|
+
self._collect_gradients()
|
243
|
+
|
244
|
+
def _collect_parameters(self, model, edge_index):
|
245
|
+
"""Collect parameters for optimization."""
|
246
|
+
parameters = []
|
247
|
+
|
248
|
+
if self.is_hetero:
|
249
|
+
# For heterogeneous graphs, collect parameters from all types
|
250
|
+
for mask in self.node_mask.values():
|
251
|
+
if mask is not None:
|
252
|
+
parameters.append(mask)
|
253
|
+
if any(v is not None for v in self.edge_mask.values()):
|
254
|
+
set_hetero_masks(model, self.edge_mask, edge_index)
|
255
|
+
for mask in self.edge_mask.values():
|
256
|
+
if mask is not None:
|
257
|
+
parameters.append(mask)
|
258
|
+
else:
|
259
|
+
# For homogeneous graphs, collect single parameters
|
260
|
+
if self.node_mask is not None:
|
261
|
+
parameters.append(self.node_mask)
|
262
|
+
if self.edge_mask is not None:
|
263
|
+
set_masks(model, self.edge_mask, edge_index,
|
264
|
+
apply_sigmoid=True)
|
265
|
+
parameters.append(self.edge_mask)
|
266
|
+
|
267
|
+
return parameters
|
268
|
+
|
269
|
+
@overload
|
270
|
+
def _forward_with_masks(
|
271
|
+
self,
|
272
|
+
model: torch.nn.Module,
|
273
|
+
x: Tensor,
|
274
|
+
edge_index: Tensor,
|
275
|
+
**kwargs,
|
276
|
+
) -> Tensor:
|
277
|
+
...
|
278
|
+
|
279
|
+
@overload
|
280
|
+
def _forward_with_masks(
|
281
|
+
self,
|
282
|
+
model: torch.nn.Module,
|
283
|
+
x: Dict[NodeType, Tensor],
|
284
|
+
edge_index: Dict[EdgeType, Tensor],
|
285
|
+
**kwargs,
|
286
|
+
) -> Tensor:
|
287
|
+
...
|
288
|
+
|
289
|
+
def _forward_with_masks(
|
290
|
+
self,
|
291
|
+
model: torch.nn.Module,
|
292
|
+
x: Union[Tensor, Dict[NodeType, Tensor]],
|
293
|
+
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
|
294
|
+
**kwargs,
|
295
|
+
) -> Tensor:
|
296
|
+
"""Forward pass with masked inputs."""
|
297
|
+
if self.is_hetero:
|
298
|
+
# Apply masks to heterogeneous inputs
|
299
|
+
h_dict = {}
|
300
|
+
for node_type, features in x.items():
|
301
|
+
if node_type in self.node_mask and self.node_mask[
|
302
|
+
node_type] is not None:
|
303
|
+
h_dict[node_type] = features * self.node_mask[
|
304
|
+
node_type].sigmoid()
|
305
|
+
else:
|
306
|
+
h_dict[node_type] = features
|
307
|
+
|
308
|
+
# Forward pass with masked features
|
309
|
+
return model(h_dict, edge_index, **kwargs)
|
310
|
+
else:
|
311
|
+
# Apply mask to homogeneous input
|
312
|
+
h = x if self.node_mask is None else x * self.node_mask.sigmoid()
|
313
|
+
|
314
|
+
# Forward pass with masked features
|
315
|
+
return model(h, edge_index, **kwargs)
|
316
|
+
|
317
|
+
def _initialize_masks(
|
318
|
+
self,
|
319
|
+
x: Union[Tensor, Dict[NodeType, Tensor]],
|
320
|
+
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
|
321
|
+
) -> None:
|
161
322
|
node_mask_type = self.explainer_config.node_mask_type
|
162
323
|
edge_mask_type = self.explainer_config.edge_mask_type
|
163
324
|
|
164
|
-
|
165
|
-
|
325
|
+
if self.is_hetero:
|
326
|
+
# Initialize dictionaries for heterogeneous masks
|
327
|
+
self.node_mask = {}
|
328
|
+
self.hard_node_mask = {}
|
329
|
+
self.edge_mask = {}
|
330
|
+
self.hard_edge_mask = {}
|
331
|
+
|
332
|
+
# Initialize node masks for each node type
|
333
|
+
for node_type, features in x.items():
|
334
|
+
device = features.device
|
335
|
+
N, F = features.size()
|
336
|
+
self._initialize_node_mask(node_mask_type, node_type, N, F,
|
337
|
+
device)
|
338
|
+
|
339
|
+
# Initialize edge masks for each edge type
|
340
|
+
for edge_type, indices in edge_index.items():
|
341
|
+
device = indices.device
|
342
|
+
E = indices.size(1)
|
343
|
+
N = max(indices.max().item() + 1,
|
344
|
+
max(feat.size(0) for feat in x.values()))
|
345
|
+
self._initialize_edge_mask(edge_mask_type, edge_type, E, N,
|
346
|
+
device)
|
347
|
+
else:
|
348
|
+
# Initialize masks for homogeneous graph
|
349
|
+
device = x.device
|
350
|
+
(N, F), E = x.size(), edge_index.size(1)
|
351
|
+
|
352
|
+
# Initialize homogeneous node and edge masks
|
353
|
+
self._initialize_homogeneous_masks(node_mask_type, edge_mask_type,
|
354
|
+
N, F, E, device)
|
355
|
+
|
356
|
+
def _initialize_node_mask(
|
357
|
+
self,
|
358
|
+
node_mask_type,
|
359
|
+
node_type,
|
360
|
+
N,
|
361
|
+
F,
|
362
|
+
device,
|
363
|
+
) -> None:
|
364
|
+
"""Initialize node mask for a specific node type."""
|
365
|
+
std = 0.1
|
366
|
+
if node_mask_type is None:
|
367
|
+
self.node_mask[node_type] = None
|
368
|
+
self.hard_node_mask[node_type] = None
|
369
|
+
elif node_mask_type == MaskType.object:
|
370
|
+
self.node_mask[node_type] = Parameter(
|
371
|
+
torch.randn(N, 1, device=device) * std)
|
372
|
+
self.hard_node_mask[node_type] = None
|
373
|
+
elif node_mask_type == MaskType.attributes:
|
374
|
+
self.node_mask[node_type] = Parameter(
|
375
|
+
torch.randn(N, F, device=device) * std)
|
376
|
+
self.hard_node_mask[node_type] = None
|
377
|
+
elif node_mask_type == MaskType.common_attributes:
|
378
|
+
self.node_mask[node_type] = Parameter(
|
379
|
+
torch.randn(1, F, device=device) * std)
|
380
|
+
self.hard_node_mask[node_type] = None
|
381
|
+
else:
|
382
|
+
raise ValueError(f"Invalid node mask type: {node_mask_type}")
|
383
|
+
|
384
|
+
def _initialize_edge_mask(self, edge_mask_type, edge_type, E, N, device):
|
385
|
+
"""Initialize edge mask for a specific edge type."""
|
386
|
+
if edge_mask_type is None:
|
387
|
+
self.edge_mask[edge_type] = None
|
388
|
+
self.hard_edge_mask[edge_type] = None
|
389
|
+
elif edge_mask_type == MaskType.object:
|
390
|
+
std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
|
391
|
+
self.edge_mask[edge_type] = Parameter(
|
392
|
+
torch.randn(E, device=device) * std)
|
393
|
+
self.hard_edge_mask[edge_type] = None
|
394
|
+
else:
|
395
|
+
raise ValueError(f"Invalid edge mask type: {edge_mask_type}")
|
166
396
|
|
397
|
+
def _initialize_homogeneous_masks(self, node_mask_type, edge_mask_type, N,
|
398
|
+
F, E, device):
|
399
|
+
"""Initialize masks for homogeneous graph."""
|
400
|
+
# Initialize node mask
|
167
401
|
std = 0.1
|
168
402
|
if node_mask_type is None:
|
169
403
|
self.node_mask = None
|
@@ -174,43 +408,145 @@ class GNNExplainer(ExplainerAlgorithm):
|
|
174
408
|
elif node_mask_type == MaskType.common_attributes:
|
175
409
|
self.node_mask = Parameter(torch.randn(1, F, device=device) * std)
|
176
410
|
else:
|
177
|
-
|
411
|
+
raise ValueError(f"Invalid node mask type: {node_mask_type}")
|
178
412
|
|
413
|
+
# Initialize edge mask
|
179
414
|
if edge_mask_type is None:
|
180
415
|
self.edge_mask = None
|
181
416
|
elif edge_mask_type == MaskType.object:
|
182
417
|
std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
|
183
418
|
self.edge_mask = Parameter(torch.randn(E, device=device) * std)
|
184
419
|
else:
|
185
|
-
|
420
|
+
raise ValueError(f"Invalid edge mask type: {edge_mask_type}")
|
421
|
+
|
422
|
+
def _collect_gradients(self) -> None:
|
423
|
+
if self.is_hetero:
|
424
|
+
self._collect_hetero_gradients()
|
425
|
+
else:
|
426
|
+
self._collect_homo_gradients()
|
427
|
+
|
428
|
+
def _collect_hetero_gradients(self):
|
429
|
+
"""Collect gradients for heterogeneous graph."""
|
430
|
+
for node_type, mask in self.node_mask.items():
|
431
|
+
if mask is not None:
|
432
|
+
if mask.grad is None:
|
433
|
+
raise ValueError(
|
434
|
+
f"Could not compute gradients for node masks of type "
|
435
|
+
f"'{node_type}'. Please make sure that node masks are "
|
436
|
+
f"used inside the model or disable it via "
|
437
|
+
f"`node_mask_type=None`.")
|
438
|
+
|
439
|
+
self.hard_node_mask[node_type] = mask.grad != 0.0
|
440
|
+
|
441
|
+
for edge_type, mask in self.edge_mask.items():
|
442
|
+
if mask is not None:
|
443
|
+
if mask.grad is None:
|
444
|
+
raise ValueError(
|
445
|
+
f"Could not compute gradients for edge masks of type "
|
446
|
+
f"'{edge_type}'. Please make sure that edge masks are "
|
447
|
+
f"used inside the model or disable it via "
|
448
|
+
f"`edge_mask_type=None`.")
|
449
|
+
self.hard_edge_mask[edge_type] = mask.grad != 0.0
|
450
|
+
|
451
|
+
def _collect_homo_gradients(self):
|
452
|
+
"""Collect gradients for homogeneous graph."""
|
453
|
+
if self.node_mask is not None:
|
454
|
+
if self.node_mask.grad is None:
|
455
|
+
raise ValueError("Could not compute gradients for node "
|
456
|
+
"features. Please make sure that node "
|
457
|
+
"features are used inside the model or "
|
458
|
+
"disable it via `node_mask_type=None`.")
|
459
|
+
self.hard_node_mask = self.node_mask.grad != 0.0
|
460
|
+
|
461
|
+
if self.edge_mask is not None:
|
462
|
+
if self.edge_mask.grad is None:
|
463
|
+
raise ValueError("Could not compute gradients for edges. "
|
464
|
+
"Please make sure that edges are used "
|
465
|
+
"via message passing inside the model or "
|
466
|
+
"disable it via `edge_mask_type=None`.")
|
467
|
+
self.hard_edge_mask = self.edge_mask.grad != 0.0
|
186
468
|
|
187
469
|
def _loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
|
470
|
+
# Calculate base loss based on model configuration
|
471
|
+
loss = self._calculate_base_loss(y_hat, y)
|
472
|
+
|
473
|
+
# Apply regularization based on graph type
|
474
|
+
if self.is_hetero:
|
475
|
+
# Apply regularization for heterogeneous graph
|
476
|
+
loss = self._apply_hetero_regularization(loss)
|
477
|
+
else:
|
478
|
+
# Apply regularization for homogeneous graph
|
479
|
+
loss = self._apply_homo_regularization(loss)
|
480
|
+
|
481
|
+
return loss
|
482
|
+
|
483
|
+
def _calculate_base_loss(self, y_hat, y):
|
484
|
+
"""Calculate base loss based on model configuration."""
|
188
485
|
if self.model_config.mode == ModelMode.binary_classification:
|
189
|
-
|
486
|
+
return self._loss_binary_classification(y_hat, y)
|
190
487
|
elif self.model_config.mode == ModelMode.multiclass_classification:
|
191
|
-
|
488
|
+
return self._loss_multiclass_classification(y_hat, y)
|
192
489
|
elif self.model_config.mode == ModelMode.regression:
|
193
|
-
|
490
|
+
return self._loss_regression(y_hat, y)
|
194
491
|
else:
|
195
|
-
|
492
|
+
raise ValueError(f"Invalid model mode: {self.model_config.mode}")
|
493
|
+
|
494
|
+
def _apply_hetero_regularization(self, loss):
|
495
|
+
"""Apply regularization for heterogeneous graph."""
|
496
|
+
# Apply regularization for each edge type
|
497
|
+
for edge_type, mask in self.edge_mask.items():
|
498
|
+
if (mask is not None
|
499
|
+
and self.hard_edge_mask[edge_type] is not None):
|
500
|
+
loss = self._add_mask_regularization(
|
501
|
+
loss, mask, self.hard_edge_mask[edge_type],
|
502
|
+
self.coeffs['edge_size'], self.coeffs['edge_reduction'],
|
503
|
+
self.coeffs['edge_ent'])
|
504
|
+
|
505
|
+
# Apply regularization for each node type
|
506
|
+
for node_type, mask in self.node_mask.items():
|
507
|
+
if (mask is not None
|
508
|
+
and self.hard_node_mask[node_type] is not None):
|
509
|
+
loss = self._add_mask_regularization(
|
510
|
+
loss, mask, self.hard_node_mask[node_type],
|
511
|
+
self.coeffs['node_feat_size'],
|
512
|
+
self.coeffs['node_feat_reduction'],
|
513
|
+
self.coeffs['node_feat_ent'])
|
196
514
|
|
515
|
+
return loss
|
516
|
+
|
517
|
+
def _apply_homo_regularization(self, loss):
|
518
|
+
"""Apply regularization for homogeneous graph."""
|
519
|
+
# Apply regularization for edge mask
|
197
520
|
if self.hard_edge_mask is not None:
|
198
521
|
assert self.edge_mask is not None
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
loss = loss + self.coeffs['edge_ent'] * ent.mean()
|
522
|
+
loss = self._add_mask_regularization(loss, self.edge_mask,
|
523
|
+
self.hard_edge_mask,
|
524
|
+
self.coeffs['edge_size'],
|
525
|
+
self.coeffs['edge_reduction'],
|
526
|
+
self.coeffs['edge_ent'])
|
205
527
|
|
528
|
+
# Apply regularization for node mask
|
206
529
|
if self.hard_node_mask is not None:
|
207
530
|
assert self.node_mask is not None
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
531
|
+
loss = self._add_mask_regularization(
|
532
|
+
loss, self.node_mask, self.hard_node_mask,
|
533
|
+
self.coeffs['node_feat_size'],
|
534
|
+
self.coeffs['node_feat_reduction'],
|
535
|
+
self.coeffs['node_feat_ent'])
|
536
|
+
|
537
|
+
return loss
|
538
|
+
|
539
|
+
def _add_mask_regularization(self, loss, mask, hard_mask, size_coeff,
|
540
|
+
reduction_name, ent_coeff):
|
541
|
+
"""Add size and entropy regularization for a mask."""
|
542
|
+
m = mask[hard_mask].sigmoid()
|
543
|
+
reduce_fn = getattr(torch, reduction_name)
|
544
|
+
# Add size regularization
|
545
|
+
loss = loss + size_coeff * reduce_fn(m)
|
546
|
+
# Add entropy regularization
|
547
|
+
ent = -m * torch.log(m + self.coeffs['EPS']) - (
|
548
|
+
1 - m) * torch.log(1 - m + self.coeffs['EPS'])
|
549
|
+
loss = loss + ent_coeff * ent.mean()
|
214
550
|
|
215
551
|
return loss
|
216
552
|
|
File without changes
|
{pyg_nightly-2.7.0.dev20250405.dist-info → pyg_nightly-2.7.0.dev20250407.dist-info}/licenses/LICENSE
RENAMED
File without changes
|