pyg-nightly 2.7.0.dev20250825__py3-none-any.whl → 2.7.0.dev20250827__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- {pyg_nightly-2.7.0.dev20250825.dist-info → pyg_nightly-2.7.0.dev20250827.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250825.dist-info → pyg_nightly-2.7.0.dev20250827.dist-info}/RECORD +13 -12
- torch_geometric/__init__.py +3 -2
- torch_geometric/_onnx.py +214 -0
- torch_geometric/loader/link_neighbor_loader.py +1 -0
- torch_geometric/nn/models/__init__.py +3 -0
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +258 -3
- torch_geometric/sampler/neighbor_sampler.py +283 -13
- torch_geometric/sampler/utils.py +48 -5
- {pyg_nightly-2.7.0.dev20250825.dist-info → pyg_nightly-2.7.0.dev20250827.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250825.dist-info → pyg_nightly-2.7.0.dev20250827.dist-info}/licenses/LICENSE +0 -0
torch_geometric/nn/models/lpformer.py
@@ -0,0 +1,783 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Parameter
+
+from ...nn.conv import MessagePassing
+from ...nn.dense.linear import Linear
+from ...nn.inits import glorot, zeros
+from ...typing import Adj, OptTensor, Tuple
+from ...utils import get_ppr, is_sparse, scatter, softmax
+from .basic_gnn import GCN
+
+
+class LPFormer(nn.Module):
+    r"""The LPFormer model from the
+    `"LPFormer: An Adaptive Graph Transformer for Link Prediction"
+    <https://arxiv.org/abs/2310.11009>`_ paper.
+
+    .. note::
+
+        For an example of using LPFormer, see
+        `examples/lpformer.py
+        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
+        lpformer.py>`_.
+
+    Args:
+        in_channels (int): Size of input dimension
+        hidden_channels (int): Size of hidden dimension
+        num_gnn_layers (int, optional): Number of GNN layers
+            (default: :obj:`2`)
+        gnn_dropout(float, optional): Dropout used for GNN
+            (default: :obj:`0.1`)
+        num_transformer_layers (int, optional): Number of Transformer layers
+            (default: :obj:`1`)
+        num_heads (int, optional): Number of heads to use in MHA
+            (default: :obj:`1`)
+        transformer_dropout (float, optional): Dropout used for Transformer
+            (default: :obj:`0.1`)
+        ppr_thresholds (list): PPR thresholds for different types of nodes.
+            Types include (in order) common neighbors, 1-Hop nodes
+            (that aren't CNs), and all other nodes.
+            (default: :obj:`[0, 1e-4, 1e-2]`)
+        gcn_cache (bool, optional): Whether to cache edge indices
+            during message passing. (default: :obj:`False`)
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        num_gnn_layers: int = 2,
+        gnn_dropout: float = 0.1,
+        num_transformer_layers: int = 1,
+        num_heads: int = 1,
+        transformer_dropout: float = 0.1,
+        ppr_thresholds: list = None,
+        gcn_cache=False,
+    ):
+        super().__init__()
+
+        # Default thresholds
+        if ppr_thresholds is None:
+            ppr_thresholds = [0, 1e-4, 1e-2]
+
+        if len(ppr_thresholds) == 3:
+            self.thresh_cn = ppr_thresholds[0]
+            self.thresh_1hop = ppr_thresholds[1]
+            self.thresh_non1hop = ppr_thresholds[2]
+        else:
+            raise ValueError(
+                "Argument 'ppr_thresholds' must only be length 3!")
+
+        self.in_dim = in_channels
+        self.hid_dim = hidden_channels
+        self.gnn_drop = gnn_dropout
+        self.trans_drop = transformer_dropout
+
+        self.gnn = GCN(in_channels, hidden_channels, num_gnn_layers,
+                       dropout=gnn_dropout, norm="layer_norm",
+                       cached=gcn_cache)
+        self.gnn_norm = nn.LayerNorm(hidden_channels)
+
+        # Create Transformer Layers
+        self.att_layers = nn.ModuleList()
+        for il in range(num_transformer_layers):
+            if il == 0:
+                node_dim = None
+                self.out_dim = self.hid_dim * 2 if num_transformer_layers > 1 \
+                    else self.hid_dim
+            elif il == self.num_layers - 1:
+                node_dim = self.hid_dim
+            else:
+                self.out_dim = node_dim = self.hid_dim
+
+            self.att_layers.append(
+                LPAttLayer(self.hid_dim, self.out_dim, node_dim, num_heads,
+                           self.trans_drop))
+
+        self.elementwise_lin = MLP(self.hid_dim, self.hid_dim, self.hid_dim)
+
+        # Relative Positional Encodings
+        self.ppr_encoder_cn = MLP(2, self.hid_dim, self.hid_dim)
+        self.ppr_encoder_onehop = MLP(2, self.hid_dim, self.hid_dim)
+        self.ppr_encoder_non1hop = MLP(2, self.hid_dim, self.hid_dim)
+
+        # thresh=1 implies ignoring some set of nodes
+        # Also allows us to be more efficient later
+        if self.thresh_non1hop == 1 and self.thresh_1hop == 1:
+            self.mask = "cn"
+        elif self.thresh_non1hop == 1 and self.thresh_1hop < 1:
+            self.mask = "1-hop"
+        else:
+            self.mask = "all"
+
+        # 4 is for counts of diff nodes
+        pairwise_dim = self.hid_dim * num_heads + 4
+        self.pairwise_lin = MLP(pairwise_dim, pairwise_dim, self.hid_dim)
+
+        self.score_func = MLP(self.hid_dim * 2, self.hid_dim * 2, 1, norm=None)
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}({self.in_dim}, '
+                f'{self.hid_dim}, num_gnn_layers={self.gnn.num_layers}, '
+                f'num_transformer_layers={len(self.att_layers)})')
+
+    def reset_parameters(self):
+        r"""Resets all learnable parameters of the module."""
+        self.gnn.reset_parameters()
+        self.gnn_norm.reset_parameters()
+        self.elementwise_lin.reset_parameters()
+        self.pairwise_lin.reset_parameters()
+        self.ppr_encoder_cn.reset_parameters()
+        self.ppr_encoder_onehop.reset_parameters()
+        self.ppr_encoder_non1hop.reset_parameters()
+        self.score_func.reset_parameters()
+        for i in range(len(self.att_layers)):
+            self.att_layers[i].reset_parameters()
+
+    def forward(
+        self,
+        batch: Tensor,
+        x: Tensor,
+        edge_index: Adj,
+        ppr_matrix: Tensor,
+    ) -> Tensor:
+        r"""Forward Pass of LPFormer.
+
+        Returns raw logits for each link
+
+        Args:
+            batch (Tensor): The batch vector.
+                Denotes which node pairs to predict.
+            x (Tensor): Input node features
+            edge_index (torch.Tensor, SparseTensor): The edge indices.
+                Either in COO or SparseTensor format
+            ppr_matrix (Tensor): PPR matrix
+        """
+        batch = batch.to(x.device)
+
+        X_node = self.propagate(x, edge_index)
+        x_i, x_j = X_node[batch[0]], X_node[batch[1]]
+        elementwise_edge_feats = self.elementwise_lin(x_i * x_j)
+
+        # Ensure in sparse format
+        # Need as native torch.sparse for later computations
+        # (necessary operations are not supported by PyG SparseTensor)
+        if not edge_index.is_sparse:
+            num_nodes = ppr_matrix.size(1)
+            vals = torch.ones(len(edge_index[0]), device=edge_index.device)
+            edge_index = torch.sparse_coo_tensor(edge_index, vals,
+                                                 [num_nodes, num_nodes])
+        # Checks if SparseTensor, if so the convert
+        if is_sparse(edge_index) and not edge_index.is_sparse:
+            edge_index = edge_index.to_torch_sparse_coo_tensor()
+
+        # Ensure {0, 1}
+        edge_index = edge_index.coalesce().bool().int()
+
+        pairwise_feats = self.calc_pairwise(batch, X_node, edge_index,
+                                            ppr_matrix)
+        combined_feats = torch.cat((elementwise_edge_feats, pairwise_feats),
+                                   dim=-1)
+
+        logits = self.score_func(combined_feats)
+        return logits
+
+    def propagate(self, x: Tensor, adj: Adj) -> Tensor:
+        """Propagate via GNN.
+
+        Args:
+            x (Tensor): Node features
+            adj (torch.Tensor, SparseTensor): Adjacency matrix
+        """
+        x = F.dropout(x, p=self.gnn_drop, training=self.training)
+        X_node = self.gnn(x, adj)
+        X_node = self.gnn_norm(X_node)
+
+        return X_node
+
+    def calc_pairwise(self, batch: Tensor, X_node: Tensor, adj_mask: Tensor,
+                      ppr_matrix: Tensor) -> Tensor:
+        r"""Calculate the pairwise features for the node pairs.
+
+        Args:
+            batch (Tensor): The batch vector.
+                Denotes which node pairs to predict.
+            X_node (Tensor): Node representations
+            adj_mask (Tensor): Mask of adjacency matrix used for computing the
+                different node types.
+            ppr_matrix (Tensor): PPR matrix
+        """
+        k_i, k_j = X_node[batch[0]], X_node[batch[1]]
+        pairwise_feats = torch.cat((k_i, k_j), dim=-1)
+
+        cn_info, onehop_info, non1hop_info = self.compute_node_mask(
+            batch, adj_mask, ppr_matrix)
+
+        all_mask = cn_info[0]
+        if onehop_info is not None:
+            all_mask = torch.cat((all_mask, onehop_info[0]), dim=-1)
+        if non1hop_info is not None:
+            all_mask = torch.cat((all_mask, non1hop_info[0]), dim=-1)
+
+        pes = self.get_pos_encodings(cn_info[1:], onehop_info[1:],
+                                     non1hop_info[1:])
+
+        for lay in range(len(self.att_layers)):
+            pairwise_feats = self.att_layers[lay](all_mask, pairwise_feats,
+                                                  X_node, pes)
+
+        num_cns, num_1hop, num_non1hop, num_neigh = self.get_structure_cnts(
+            batch, cn_info, onehop_info, non1hop_info)
+
+        pairwise_feats = torch.cat(
+            (pairwise_feats, num_cns, num_1hop, num_non1hop, num_neigh),
+            dim=-1)
+
+        pairwise_feats = self.pairwise_lin(pairwise_feats)
+        return pairwise_feats
+
+    def get_pos_encodings(
+            self, cn_ppr: Tuple[Tensor, Tensor],
+            onehop_ppr: Optional[Tuple[Tensor, Tensor]] = None,
+            non1hop_ppr: Optional[Tuple[Tensor, Tensor]] = None) -> Tensor:
+        r"""Calculate the PPR-based relative positional encodings.
+
+        Due to thresholds, sometimes we don't have 1-hop or >1-hop nodes.
+        In those cases, the value of onehop_ppr and/or non1hop_ppr should
+        be `None`.
+
+        Args:
+            cn_ppr (tuple, optional): PPR scores of CNs.
+            onehop_ppr (tuple, optional): PPR scores of 1-Hop.
+                (default: :obj:`None`)
+            non1hop_ppr (tuple, optional): PPR scores of >1-Hop.
+                (default: :obj:`None`)
+        """
+        cn_a = self.ppr_encoder_cn(torch.stack((cn_ppr[0], cn_ppr[1])).t())
+        cn_b = self.ppr_encoder_cn(torch.stack((cn_ppr[1], cn_ppr[0])).t())
+        cn_pe = cn_a + cn_b
+
+        if onehop_ppr is None:
+            return cn_pe
+
+        onehop_a = self.ppr_encoder_onehop(
+            torch.stack((onehop_ppr[0], onehop_ppr[1])).t())
+        onehop_b = self.ppr_encoder_onehop(
+            torch.stack((onehop_ppr[1], onehop_ppr[0])).t())
+        onehop_pe = onehop_a + onehop_b
+
+        if non1hop_ppr is None:
+            return torch.cat((cn_pe, onehop_pe), dim=0)
+
+        non1hop_a = self.ppr_encoder_non1hop(
+            torch.stack((non1hop_ppr[0], non1hop_ppr[1])).t())
+        non1hop_b = self.ppr_encoder_non1hop(
+            torch.stack((non1hop_ppr[1], non1hop_ppr[0])).t())
+        non1hop_pe = non1hop_a + non1hop_b
+
+        return torch.cat((cn_pe, onehop_pe, non1hop_pe), dim=0)
+
+    def compute_node_mask(
+            self, batch: Tensor, adj: Tensor, ppr_matrix: Tensor
+    ) -> Tuple[Tuple, Optional[Tuple], Optional[Tuple]]:
+        r"""Get mask based on type of node.
+
+        When mask_type is not "cn", also return the ppr vals for both
+        the source and target.
+
+        Args:
+            batch (Tensor): The batch vector.
+                Denotes which node pairs to predict.
+            adj (SparseTensor): Adjacency matrix
+            ppr_matrix (Tensor): PPR matrix
+        """
+        src_adj = torch.index_select(adj, 0, batch[0])
+        tgt_adj = torch.index_select(adj, 0, batch[1])
+
+        if self.mask == "cn":
+            # 1 when CN, 0 otherwise
+            pair_adj = src_adj * tgt_adj
+        else:
+            # Equals: {0: ">1-Hop", 1: "1-Hop (Non-CN)", 2: "CN"}
+            pair_adj = src_adj + tgt_adj
+
+        pair_ix, node_type, src_ppr, tgt_ppr = self.get_ppr_vals(
+            batch, pair_adj, ppr_matrix)
+
+        cn_filt_cond = (src_ppr >= self.thresh_cn) & (tgt_ppr
+                                                      >= self.thresh_cn)
+        onehop_filt_cond = (src_ppr >= self.thresh_1hop) & (
+            tgt_ppr >= self.thresh_1hop)
+
+        if self.mask != "cn":
+            filt_cond = torch.where(node_type == 1, onehop_filt_cond,
+                                    cn_filt_cond)
+        else:
+            filt_cond = torch.where(node_type == 0, onehop_filt_cond,
+                                    cn_filt_cond)
+
+        pair_ix, node_type = pair_ix[:, filt_cond], node_type[filt_cond]
+        src_ppr, tgt_ppr = src_ppr[filt_cond], tgt_ppr[filt_cond]
+
+        # >1-Hop mask is gotten separately
+        if self.mask == "all":
+            non1hop_ix, non1hop_sppr, non1hop_tppr = self.get_non_1hop_ppr(
+                batch, adj, ppr_matrix)
+
+        # Dropout
+        if self.training and self.trans_drop > 0:
+            pair_ix, src_ppr, tgt_ppr, node_type = self.drop_pairwise(
+                pair_ix, src_ppr, tgt_ppr, node_type)
+            if self.mask == "all":
+                non1hop_ix, non1hop_sppr, non1hop_tppr, _ = self.drop_pairwise(
+                    non1hop_ix, non1hop_sppr, non1hop_tppr)
+
+        # Separate out CN and 1-Hop
+        if self.mask != "cn":
+            cn_ind = node_type == 2
+            cn_ix = pair_ix[:, cn_ind]
+            cn_src_ppr = src_ppr[cn_ind]
+            cn_tgt_ppr = tgt_ppr[cn_ind]
+
+            one_hop_ind = node_type == 1
+            onehop_ix = pair_ix[:, one_hop_ind]
+            onehop_src_ppr = src_ppr[one_hop_ind]
+            onehop_tgt_ppr = tgt_ppr[one_hop_ind]
+
+        if self.mask == "cn":
+            return (pair_ix, src_ppr, tgt_ppr), None, None
+        elif self.mask == "1-hop":
+            return (cn_ix, cn_src_ppr, cn_tgt_ppr), (onehop_ix, onehop_src_ppr,
+                                                     onehop_tgt_ppr), None
+        else:
+            return (cn_ix, cn_src_ppr,
+                    cn_tgt_ppr), (onehop_ix, onehop_src_ppr,
+                                  onehop_tgt_ppr), (non1hop_ix, non1hop_sppr,
+                                                    non1hop_tppr)
+
+    def get_ppr_vals(
+            self, batch: Tensor, pair_diff_adj: Tensor,
+            ppr_matrix: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+        r"""Get the src and tgt ppr vals.
+
+        Returns the: link the node belongs to, type of node
+        (e.g., CN), PPR relative to src, PPR relative to tgt.
+
+        Args:
+            batch (Tensor): The batch vector.
+                Denotes which node pairs to predict.
+            pair_diff_adj (SparseTensor): Combination of rows in
+                adjacency for src and tgt nodes (e.g., X1 + X2)
+            ppr_matrix (Tensor): PPR matrix
+        """
+        # Additional terms for also choosing scores when ppr=0
+        # Multiplication removes any values for nodes not in batch
+        # Addition then adds offset to ensure we select when ppr=0
+        # All selected scores are +1 higher than their true val
+        src_ppr_adj = torch.index_select(
+            ppr_matrix, 0, batch[0]) * pair_diff_adj + pair_diff_adj
+        tgt_ppr_adj = torch.index_select(
+            ppr_matrix, 0, batch[1]) * pair_diff_adj + pair_diff_adj
+
+        # Can now convert ppr scores to dense
+        ppr_ix = src_ppr_adj.coalesce().indices()
+        src_ppr = src_ppr_adj.coalesce().values()
+        tgt_ppr = tgt_ppr_adj.coalesce().values()
+
+        # TODO: Needed due to a bug in recent torch versions
+        # see here for more - https://github.com/pytorch/pytorch/issues/114529
+        # note that if one is 0 so is the other
+        zero_vals = (src_ppr != 0)
+        src_ppr = src_ppr[zero_vals]
+        tgt_ppr = tgt_ppr[tgt_ppr != 0]
+        ppr_ix = ppr_ix[:, zero_vals]
+
+        pair_diff_adj = pair_diff_adj.coalesce().values()
+        node_type = pair_diff_adj[src_ppr != 0]
+
+        # Remove additional +1 from each ppr val
+        src_ppr = (src_ppr - node_type) / node_type
+        tgt_ppr = (tgt_ppr - node_type) / node_type
+
+        return ppr_ix, node_type, src_ppr, tgt_ppr
+
+    def drop_pairwise(
+        self,
+        pair_ix: Tensor,
+        src_ppr: Optional[Tensor] = None,
+        tgt_ppr: Optional[Tensor] = None,
+        node_indicator: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+        r"""Perform dropout on pairwise information
+        by randomly dropping a percentage of nodes.
+
+        Done before performing attention for efficiency
+
+        Args:
+            pair_ix (Tensor): Link node belongs to
+            src_ppr (Tensor, optional): PPR relative to src
+                (default: :obj:`None`)
+            tgt_ppr (Tensor, optional): PPR relative to tgt
+                (default: :obj:`None`)
+            node_indicator (Tensor, optional): Type of node (e.g., CN)
+                (default: :obj:`None`)
+        """
+        num_indices = math.ceil(pair_ix.size(1) * (1 - self.trans_drop))
+        indices = torch.randperm(pair_ix.size(1))[:num_indices]
+        pair_ix = pair_ix[:, indices]
+
+        if src_ppr is not None:
+            src_ppr = src_ppr[indices]
+        if tgt_ppr is not None:
+            tgt_ppr = tgt_ppr[indices]
+        if node_indicator is not None:
+            node_indicator = node_indicator[indices]
+
+        return pair_ix, src_ppr, tgt_ppr, node_indicator
+
+    def get_structure_cnts(
+        self,
+        batch: Tensor,
+        cn_info: Tuple[Tensor, Tensor],
+        onehop_info: Tuple[Tensor, Tensor],
+        non1hop_info: Optional[Tuple[Tensor, Tensor]],
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+        """Counts for CNs, 1-Hop, and >1-Hop that satisfy PPR threshold.
+
+        Also include total # of neighbors
+
+        Args:
+            batch (Tensor): The batch vector.
+                Denotes which node pairs to predict.
+            cn_info (tuple): Information of CN nodes
+                Contains (ID of node, src ppr, tgt ppr)
+            onehop_info (tuple): Information of 1-Hop nodes.
+                Contains (ID of node, src ppr, tgt ppr)
+            non1hop_info (tuple): Information of >1-Hop nodes.
+                Contains (ID of node, src ppr, tgt ppr)
+        """
+        num_cns = self.get_num_ppr_thresh(batch, cn_info[0], cn_info[1],
+                                          cn_info[2], self.thresh_cn)
+        num_1hop = self.get_num_ppr_thresh(batch, onehop_info[0],
+                                           onehop_info[1], onehop_info[2],
+                                           self.thresh_1hop)
+
+        # TOTAL num of 1-hop neighbors union
+        num_ppr_ones = self.get_num_ppr_thresh(batch, onehop_info[0],
+                                               onehop_info[1], onehop_info[2],
+                                               thresh=0)
+        num_neighbors = num_cns + num_ppr_ones
+
+        # Process for >1-hop is different which is why we use get_count below
+        if non1hop_info is None:
+            return num_cns, num_1hop, 0, num_neighbors
+        else:
+            num_non1hop = self.get_count(non1hop_info[0], batch)
+            return num_cns, num_1hop, num_non1hop, num_neighbors
+
+    def get_num_ppr_thresh(self, batch: Tensor, node_mask: Tensor,
+                           src_ppr: Tensor, tgt_ppr: Tensor,
+                           thresh: float) -> Tensor:
+        """Get # of nodes `v` where `ppr(a, v) >= eta` & `ppr(b, v) >= eta`.
+
+        Args:
+            batch (Tensor): The batch vector.
+                Denotes which node pairs to predict.
+            node_mask (Tensor): IDs of nodes
+            src_ppr (Tensor): PPR relative to src node
+            tgt_ppr (Tensor): PPR relative to tgt node
+            thresh (float): PPR threshold for nodes (`eta`)
+        """
+        weight = torch.ones(node_mask.size(1), device=node_mask.device)
+
+        ppr_above_thresh = (src_ppr >= thresh) & (tgt_ppr >= thresh)
+        num_ppr = scatter(ppr_above_thresh.float() * weight,
+                          node_mask[0].long(), dim=0, dim_size=batch.size(1),
+                          reduce="sum")
+        num_ppr = num_ppr.unsqueeze(-1)
+
+        return num_ppr
+
+    def get_count(
+        self,
+        node_mask: Tensor,
+        batch: Tensor,
+    ) -> Tensor:
+        """# of nodes for each sample in batch.
+
+        They node have already filtered by PPR beforehand
+
+        Args:
+            node_mask (Tensor): IDs of nodes
+            batch (Tensor): The batch vector.
+                Denotes which node pairs to predict.
+        """
+        weight = torch.ones(node_mask.size(1), device=node_mask.device)
+        num_nodes = scatter(weight, node_mask[0].long(), dim=0,
+                            dim_size=batch.size(1), reduce="sum")
+        num_nodes = num_nodes.unsqueeze(-1)
+
+        return num_nodes
+
+    def get_non_1hop_ppr(self, batch: Tensor, adj: Tensor,
+                         ppr_matrix: Tensor) -> Tensor:
+        r"""Get PPR scores for non-1hop nodes.
+
+        Args:
+            batch (Tensor): Links in batch
+            adj (Tensor): Adjacency matrix
+            ppr_matrix (Tensor): Sparse PPR matrix
+        """
+        # NOTE: Use original adj (one pass in forward() removes links in batch)
+        # Done since removing them converts src/tgt nodes to >1-hop nodes.
+        # Therefore removing CN and 1-hop will also remove the batch links.
+
+        # During training we add back in the links in the batch
+        # (we're removed from adjacency before being passed to model)
+        # Done since otherwise they will be mistakenly seen as >1-Hop nodes
+        # Instead they're 1-Hop, and get ignored accordingly
+        # Ignored during eval since we know the links aren't in the adj
+        adj2 = adj
+        if self.training:
+            n = adj.size(0)
+            batch_flip = torch.cat(
+                (batch, torch.flip(batch, (0, )).to(batch.device)), dim=-1)
+            batch_ones = torch.ones_like(batch_flip[0], device=batch.device)
+            adj_edges = torch.sparse_coo_tensor(batch_flip, batch_ones, [n, n],
+                                                device=batch.device)
+            adj_edges = adj_edges
+            adj2 = (adj + adj_edges).coalesce().bool().int()
+
+        src_adj = torch.index_select(adj2, 0, batch[0])
+        tgt_adj = torch.index_select(adj2, 0, batch[1])
+
+        src_ppr = torch.index_select(ppr_matrix, 0, batch[0])
+        tgt_ppr = torch.index_select(ppr_matrix, 0, batch[1])
+
+        # Remove CN scores
+        src_ppr = src_ppr - src_ppr * (src_adj * tgt_adj)
+        tgt_ppr = tgt_ppr - tgt_ppr * (src_adj * tgt_adj)
+        # Also need to remove CN entries in Adj
+        # Otherwise they leak into next computation
+        src_adj = src_adj - src_adj * (src_adj * tgt_adj)
+        tgt_adj = tgt_adj - tgt_adj * (src_adj * tgt_adj)
+
+        # Remove 1-Hop scores
+        src_ppr = src_ppr - src_ppr * (src_adj + tgt_adj)
+        tgt_ppr = tgt_ppr - tgt_ppr * (src_adj + tgt_adj)
+
+        # Make sure we include both when we convert to dense so indices align
+        # Do so by adding 1 to each based on the other
+        src_ppr_add = src_ppr + torch.sign(tgt_ppr)
+        tgt_ppr_add = tgt_ppr + torch.sign(src_ppr)
+
+        src_ix = src_ppr_add.coalesce().indices()
+        src_vals = src_ppr_add.coalesce().values()
+        tgt_vals = tgt_ppr_add.coalesce().values()
+
+        # Now we can remove value which is just 1
+        # Technically creates -1 scores for ppr scores that were 0
+        # Doesn't matter as they'll be filtered out by condition later
+        src_vals = src_vals - 1
+        tgt_vals = tgt_vals - 1
+
+        ppr_condition = (src_vals >= self.thresh_non1hop) & (
+            tgt_vals >= self.thresh_non1hop)
+        src_ix, src_vals, tgt_vals = src_ix[:, ppr_condition], src_vals[
+            ppr_condition], tgt_vals[ppr_condition]
+
+        return src_ix, src_vals, tgt_vals
+
+    def calc_sparse_ppr(self, edge_index: Tensor, num_nodes: int,
+                        alpha: float = 0.15, eps: float = 5e-5) -> Tensor:
+        r"""Calculate the PPR of the graph in sparse format.
+
+        Args:
+            edge_index: The edge indices
+            num_nodes: Number of nodes
+            alpha (float, optional): The alpha value of the PageRank algorithm.
+                (default: :obj:`0.15`)
+            eps (float, optional): Threshold for stopping the PPR calculation
+                (default: :obj:`5e-5`)
+        """
+        ei, ei_w = get_ppr(edge_index.cpu(), alpha=alpha, eps=eps,
+                           num_nodes=num_nodes)
+        ppr_matrix = torch.sparse_coo_tensor(ei, ei_w, [num_nodes, num_nodes])
+
+        return ppr_matrix
+
+
+class LPAttLayer(MessagePassing):
+    r"""Attention Layer for pairwise interaction module.
+
+    Args:
+        in_channels (int): Size of input dimension
+        out_channels (int): Size of output dimension
+        node_dim (int): Dimension of nodes being aggregated
+        num_heads (int): Number of heads to use in MHA
+        dropout (float): Dropout on attention values
+        concat (bool, optional): Whether to concat attention
+            heads. Otherwise sum (default: :obj:`True`)
+    """
+    _alpha: OptTensor
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        node_dim: int,
+        num_heads: int,
+        dropout: float,
+        concat: bool = True,
+        **kwargs,
+    ):
+        super().__init__(node_dim=0, flow="target_to_source", **kwargs)
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.heads = num_heads
+        self.concat = concat
+        self.dropout = dropout
+        self.negative_slope = 0.2  # LeakyRelu
+
+        out_dim = 2
+        if node_dim is None:
+            node_dim = in_channels * out_dim
+        else:
+            node_dim = node_dim * out_dim
+
+        self.lin_l = Linear(in_channels, self.heads * out_channels,
+                            weight_initializer='glorot')
+        self.lin_r = Linear(node_dim, self.heads * out_channels,
+                            weight_initializer='glorot')
+
+        att_out = out_channels
+        self.att = Parameter(Tensor(1, self.heads, att_out))
+
+        if concat:
+            self.bias = Parameter(Tensor(self.heads * out_channels))
+        else:
+            self.bias = Parameter(Tensor(out_channels))
+
+        self._alpha = None
+
+        self.dropout = dropout
+        self.post_att_norm = nn.LayerNorm(out_channels)
+
+        self.reset_parameters()
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}({self.in_channels}, '
+                f'{self.out_channels}, heads={self.heads})')
+
+    def reset_parameters(self):
+        self.lin_l.reset_parameters()
+        self.lin_r.reset_parameters()
+        self.post_att_norm.reset_parameters()
+        glorot(self.att)
+        zeros(self.bias)
+
+    def forward(
+        self,
+        edge_index: Tensor,
+        edge_feats: Tensor,
+        node_feats: Tensor,
+        ppr_rpes: Tensor,
+    ) -> Tensor:
+        """Runs the forward pass of the module.
+
+        Args:
+            edge_index (Tensor): The edge indices.
+            edge_feats (Tensor): Concatenated representations
+                of src and target nodes for each link
+            node_feats (Tensor): Representations for individual
+                nodes
+            ppr_rpes (Tensor): Relative PEs for each node
+        """
+        out = self.propagate(edge_index, x=(edge_feats, node_feats),
+                             ppr_rpes=ppr_rpes, size=None)
+
+        alpha = self._alpha
+        assert alpha is not None
+        self._alpha = None
+
+        if self.concat:
+            out = out.view(-1, self.heads * self.out_channels)
+        else:
+            out = out.mean(dim=1)
+
+        if self.bias is not None:
+            out = out + self.bias
+
+        out = self.post_att_norm(out)
+        out = F.dropout(out, p=self.dropout, training=self.training)
+
+        return out
+
+    def message(self, x_i: Tensor, x_j: Tensor, ppr_rpes: Tensor,
+                index: Tensor, ptr: Tensor, size_i: Optional[int]) -> Tensor:
+        H, C = self.heads, self.out_channels
+
+        x_j = torch.cat((x_j, ppr_rpes), dim=-1)
+        x_j = self.lin_r(x_j).view(-1, H, C)
+
+        # e=(a, b) attending to v
+        e1, e2 = x_i.chunk(2, dim=-1)
+        e1 = self.lin_l(e1).view(-1, H, C)
+        e2 = self.lin_l(e2).view(-1, H, C)
+        x = x_j * (e1 + e2)
+
+        x = F.leaky_relu(x, self.negative_slope)
+        alpha = (x * self.att).sum(dim=-1)
+
+        alpha = softmax(alpha, index, ptr, size_i)
+        self._alpha = alpha
+
+        return x_j * alpha.unsqueeze(-1)
+
+
+class MLP(nn.Module):
+    r"""L Layer MLP."""
+    def __init__(self, in_channels: int, hid_channels: int, out_channels: int,
+                 num_layers: int = 2, drop: int = 0, norm: str = "layer"):
+        super().__init__()
+        self.dropout = drop
+
+        if norm == "batch":
+            self.norm = nn.BatchNorm1d(hid_channels)
+        elif norm == "layer":
+            self.norm = nn.LayerNorm(hid_channels)
+        else:
+            self.norm = None
+
+        self.linears = torch.nn.ModuleList()
+
+        if num_layers == 1:
+            self.linears.append(nn.Linear(in_channels, out_channels))
+        else:
+            self.linears.append(nn.Linear(in_channels, hid_channels))
+            for _ in range(num_layers - 2):
+                self.linears.append(nn.Linear(hid_channels, hid_channels))
+            self.linears.append(nn.Linear(hid_channels, out_channels))
+
+    def reset_parameters(self):
+        for lin in self.linears:
+            lin.reset_parameters()
+        if self.norm is not None:
+            self.norm.reset_parameters()
+
+    def forward(self, x: Tensor) -> Tensor:
+        for lin in self.linears[:-1]:
+            x = lin(x)
+            x = self.norm(x) if self.norm is not None else x
+            x = F.relu(x)
+            x = F.dropout(x, p=self.dropout, training=self.training)
+
+        x = self.linears[-1](x)
+
+        return x.squeeze(-1)
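
For orientation, a minimal usage sketch follows. It is not part of the package (the canonical example is examples/lpformer.py, referenced in the class docstring above). It assumes LPFormer is re-exported from torch_geometric.nn.models, which the __init__.py change in this release suggests, and uses a small random graph purely for illustration; torch_geometric.utils.get_ppr, called by calc_sparse_ppr, may need an optional dependency such as numba.

    import torch

    # Assumption: the new class is re-exported here, per the __init__.py change above.
    from torch_geometric.nn.models import LPFormer

    num_nodes, in_channels = 100, 16

    # Toy undirected graph and random node features, for illustration only.
    edge_index = torch.randint(0, num_nodes, (2, 400))
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
    x = torch.randn(num_nodes, in_channels)

    model = LPFormer(in_channels, hidden_channels=32)

    # Sparse personalized PageRank matrix, built with the helper defined above.
    ppr_matrix = model.calc_sparse_ppr(edge_index, num_nodes).coalesce()

    # Node pairs to score: shape [2, num_pairs], one (src, tgt) pair per column.
    batch = torch.randint(0, num_nodes, (2, 8))

    # Eval mode: skips pairwise dropout and the training-only adjacency adjustment.
    model.eval()
    with torch.no_grad():
        logits = model(batch, x, edge_index, ppr_matrix)  # one raw logit per pair

As the forward() docstring states, the output is a raw logit per candidate link; a typical training setup would apply a sigmoid and binary cross-entropy over positive and negative pairs.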