pyg-nightly 2.7.0.dev20250826__py3-none-any.whl → 2.7.0.dev20250827__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,783 @@
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+ from torch.nn import Parameter
9
+
10
+ from ...nn.conv import MessagePassing
11
+ from ...nn.dense.linear import Linear
12
+ from ...nn.inits import glorot, zeros
13
+ from ...typing import Adj, OptTensor, Tuple
14
+ from ...utils import get_ppr, is_sparse, scatter, softmax
15
+ from .basic_gnn import GCN
16
+
17
+
18
class LPFormer(nn.Module):
    r"""The LPFormer model from the
    `"LPFormer: An Adaptive Graph Transformer for Link Prediction"
    <https://arxiv.org/abs/2310.11009>`_ paper.

    .. note::

        For an example of using LPFormer, see
        `examples/lpformer.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        lpformer.py>`_.

    Args:
        in_channels (int): Size of input dimension
        hidden_channels (int): Size of hidden dimension
        num_gnn_layers (int, optional): Number of GNN layers
            (default: :obj:`2`)
        gnn_dropout (float, optional): Dropout used for GNN
            (default: :obj:`0.1`)
        num_transformer_layers (int, optional): Number of Transformer layers
            (default: :obj:`1`)
        num_heads (int, optional): Number of heads to use in MHA
            (default: :obj:`1`)
        transformer_dropout (float, optional): Dropout used for Transformer
            (default: :obj:`0.1`)
        ppr_thresholds (list, optional): PPR thresholds for different types
            of nodes. Types include (in order) common neighbors, 1-Hop nodes
            (that aren't CNs), and all other nodes. A threshold of :obj:`1`
            for both the 1-Hop and >1-Hop types restricts attention to CNs
            only; a threshold of :obj:`1` for just the >1-Hop type restricts
            it to CNs and 1-Hop nodes.
            (default: :obj:`[0, 1e-4, 1e-2]`)
        gcn_cache (bool, optional): Whether to cache edge indices
            during message passing. (default: :obj:`False`)
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        num_gnn_layers: int = 2,
        gnn_dropout: float = 0.1,
        num_transformer_layers: int = 1,
        num_heads: int = 1,
        transformer_dropout: float = 0.1,
        ppr_thresholds: Optional[list] = None,
        gcn_cache: bool = False,
    ):
        super().__init__()

        # Default thresholds (built per call so instances never share a list)
        if ppr_thresholds is None:
            ppr_thresholds = [0, 1e-4, 1e-2]

        if len(ppr_thresholds) != 3:
            raise ValueError(
                "Argument 'ppr_thresholds' must only be length 3!")
        self.thresh_cn, self.thresh_1hop, self.thresh_non1hop = \
            ppr_thresholds

        self.in_dim = in_channels
        self.hid_dim = hidden_channels
        self.gnn_drop = gnn_dropout
        self.trans_drop = transformer_dropout

        self.gnn = GCN(in_channels, hidden_channels, num_gnn_layers,
                       dropout=gnn_dropout, norm="layer_norm",
                       cached=gcn_cache)
        self.gnn_norm = nn.LayerNorm(hidden_channels)

        # Create Transformer Layers.
        # Every layer splits its incoming pairwise features into two equal
        # halves (one per link endpoint) and projects each with `lin_l`.
        # Non-final layers must therefore emit 2 * hid_dim (per head) so the
        # next layer can split them again. The final layer emits hid_dim per
        # head, which is what `pairwise_lin` below is sized for.
        self.att_layers = nn.ModuleList()
        for il in range(num_transformer_layers):
            is_last = il == num_transformer_layers - 1
            # Layer 0 consumes cat(x_a, x_b), i.e. hid_dim per endpoint.
            # Later layers consume the concatenated multi-head output of the
            # previous layer: num_heads * 2 * hid_dim, i.e.
            # num_heads * hid_dim per endpoint.
            layer_in = self.hid_dim if il == 0 else self.hid_dim * num_heads
            node_dim = None if il == 0 else self.hid_dim
            self.out_dim = self.hid_dim if is_last else self.hid_dim * 2

            self.att_layers.append(
                LPAttLayer(layer_in, self.out_dim, node_dim, num_heads,
                           self.trans_drop))

        self.elementwise_lin = MLP(self.hid_dim, self.hid_dim, self.hid_dim)

        # Relative Positional Encodings
        self.ppr_encoder_cn = MLP(2, self.hid_dim, self.hid_dim)
        self.ppr_encoder_onehop = MLP(2, self.hid_dim, self.hid_dim)
        self.ppr_encoder_non1hop = MLP(2, self.hid_dim, self.hid_dim)

        # thresh=1 implies ignoring some set of nodes
        # Also allows us to be more efficient later
        if self.thresh_non1hop == 1 and self.thresh_1hop == 1:
            self.mask = "cn"
        elif self.thresh_non1hop == 1 and self.thresh_1hop < 1:
            self.mask = "1-hop"
        else:
            self.mask = "all"

        # 4 is for counts of diff nodes
        pairwise_dim = self.hid_dim * num_heads + 4
        self.pairwise_lin = MLP(pairwise_dim, pairwise_dim, self.hid_dim)

        self.score_func = MLP(self.hid_dim * 2, self.hid_dim * 2, 1, norm=None)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_dim}, '
                f'{self.hid_dim}, num_gnn_layers={self.gnn.num_layers}, '
                f'num_transformer_layers={len(self.att_layers)})')

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.gnn.reset_parameters()
        self.gnn_norm.reset_parameters()
        self.elementwise_lin.reset_parameters()
        self.pairwise_lin.reset_parameters()
        self.ppr_encoder_cn.reset_parameters()
        self.ppr_encoder_onehop.reset_parameters()
        self.ppr_encoder_non1hop.reset_parameters()
        self.score_func.reset_parameters()
        for att_layer in self.att_layers:
            att_layer.reset_parameters()

    def forward(
        self,
        batch: Tensor,
        x: Tensor,
        edge_index: Adj,
        ppr_matrix: Tensor,
    ) -> Tensor:
        r"""Forward Pass of LPFormer.

        Returns raw logits for each link

        Args:
            batch (Tensor): The batch vector.
                Denotes which node pairs to predict. Row 0 holds the source
                and row 1 the target node of each link.
            x (Tensor): Input node features
            edge_index (torch.Tensor, SparseTensor): The edge indices.
                Either in COO or SparseTensor format
            ppr_matrix (Tensor): PPR matrix
        """
        batch = batch.to(x.device)

        X_node = self.propagate(x, edge_index)
        x_i, x_j = X_node[batch[0]], X_node[batch[1]]
        elementwise_edge_feats = self.elementwise_lin(x_i * x_j)

        # Need a native torch.sparse COO adjacency for later computations
        # (necessary operations are not supported by PyG SparseTensor).
        # The SparseTensor case must be handled first: it is not a
        # `torch.Tensor`, so it cannot be treated as a [2, E] index.
        if is_sparse(edge_index) and not isinstance(edge_index, Tensor):
            edge_index = edge_index.to_torch_sparse_coo_tensor()
        elif not edge_index.is_sparse:
            num_nodes = ppr_matrix.size(1)
            vals = torch.ones(edge_index.size(1), device=edge_index.device)
            edge_index = torch.sparse_coo_tensor(edge_index, vals,
                                                 [num_nodes, num_nodes])

        # Ensure {0, 1}
        edge_index = edge_index.coalesce().bool().int()

        pairwise_feats = self.calc_pairwise(batch, X_node, edge_index,
                                            ppr_matrix)
        combined_feats = torch.cat((elementwise_edge_feats, pairwise_feats),
                                   dim=-1)

        return self.score_func(combined_feats)

    def propagate(self, x: Tensor, adj: Adj) -> Tensor:
        """Propagate via GNN.

        Applies input dropout, the GCN encoder and a final LayerNorm.

        Args:
            x (Tensor): Node features
            adj (torch.Tensor, SparseTensor): Adjacency matrix
        """
        x = F.dropout(x, p=self.gnn_drop, training=self.training)
        X_node = self.gnn(x, adj)
        X_node = self.gnn_norm(X_node)

        return X_node

    def calc_pairwise(self, batch: Tensor, X_node: Tensor, adj_mask: Tensor,
                      ppr_matrix: Tensor) -> Tensor:
        r"""Calculate the pairwise features for the node pairs.

        Args:
            batch (Tensor): The batch vector.
                Denotes which node pairs to predict.
            X_node (Tensor): Node representations
            adj_mask (Tensor): Mask of adjacency matrix used for computing the
                different node types.
            ppr_matrix (Tensor): PPR matrix
        """
        k_i, k_j = X_node[batch[0]], X_node[batch[1]]
        pairwise_feats = torch.cat((k_i, k_j), dim=-1)

        cn_info, onehop_info, non1hop_info = self.compute_node_mask(
            batch, adj_mask, ppr_matrix)

        # Attention index over every kept node. The order (CN, 1-Hop,
        # >1-Hop) must match the order of the positional encodings below.
        all_mask = cn_info[0]
        if onehop_info is not None:
            all_mask = torch.cat((all_mask, onehop_info[0]), dim=-1)
        if non1hop_info is not None:
            all_mask = torch.cat((all_mask, non1hop_info[0]), dim=-1)

        # Node types disabled through the thresholds come back as `None`
        # and must be forwarded as `None`, not sliced.
        pes = self.get_pos_encodings(
            cn_info[1:],
            None if onehop_info is None else onehop_info[1:],
            None if non1hop_info is None else non1hop_info[1:],
        )

        for att_layer in self.att_layers:
            pairwise_feats = att_layer(all_mask, pairwise_feats, X_node, pes)

        num_cns, num_1hop, num_non1hop, num_neigh = self.get_structure_cnts(
            batch, cn_info, onehop_info, non1hop_info)

        pairwise_feats = torch.cat(
            (pairwise_feats, num_cns, num_1hop, num_non1hop, num_neigh),
            dim=-1)

        return self.pairwise_lin(pairwise_feats)

    @staticmethod
    def _encode_ppr(encoder: nn.Module, ppr: Tuple[Tensor, Tensor]) -> Tensor:
        """Symmetric PE: encode both (src, tgt) and (tgt, src) and sum."""
        fwd = torch.stack((ppr[0], ppr[1]), dim=-1)
        return encoder(fwd) + encoder(fwd.flip(-1))

    def get_pos_encodings(
            self, cn_ppr: Tuple[Tensor, Tensor],
            onehop_ppr: Optional[Tuple[Tensor, Tensor]] = None,
            non1hop_ppr: Optional[Tuple[Tensor, Tensor]] = None) -> Tensor:
        r"""Calculate the PPR-based relative positional encodings.

        Due to thresholds, sometimes we don't have 1-hop or >1-hop nodes.
        In those cases, the value of onehop_ppr and/or non1hop_ppr should
        be `None`.

        Args:
            cn_ppr (tuple): PPR scores of CNs.
            onehop_ppr (tuple, optional): PPR scores of 1-Hop.
                (default: :obj:`None`)
            non1hop_ppr (tuple, optional): PPR scores of >1-Hop.
                (default: :obj:`None`)
        """
        cn_pe = self._encode_ppr(self.ppr_encoder_cn, cn_ppr)

        if onehop_ppr is None:
            return cn_pe

        onehop_pe = self._encode_ppr(self.ppr_encoder_onehop, onehop_ppr)

        if non1hop_ppr is None:
            return torch.cat((cn_pe, onehop_pe), dim=0)

        non1hop_pe = self._encode_ppr(self.ppr_encoder_non1hop, non1hop_ppr)

        return torch.cat((cn_pe, onehop_pe, non1hop_pe), dim=0)

    def compute_node_mask(
        self, batch: Tensor, adj: Tensor, ppr_matrix: Tensor
    ) -> Tuple[Tuple, Optional[Tuple], Optional[Tuple]]:
        r"""Get mask based on type of node.

        Returns ``(cn_info, onehop_info, non1hop_info)``. Each entry is a
        tuple ``(pair_ix, src_ppr, tgt_ppr)``, where row 0 of ``pair_ix`` is
        the position of the link in :obj:`batch` and row 1 is the node id.
        ``onehop_info`` is :obj:`None` in "cn" mode and ``non1hop_info`` is
        :obj:`None` unless the mode is "all".

        Args:
            batch (Tensor): The batch vector.
                Denotes which node pairs to predict.
            adj (SparseTensor): Adjacency matrix
            ppr_matrix (Tensor): PPR matrix
        """
        src_adj = torch.index_select(adj, 0, batch[0])
        tgt_adj = torch.index_select(adj, 0, batch[1])

        if self.mask == "cn":
            # 1 when CN, 0 otherwise
            pair_adj = src_adj * tgt_adj
        else:
            # Equals: {0: ">1-Hop", 1: "1-Hop (Non-CN)", 2: "CN"}
            pair_adj = src_adj + tgt_adj

        pair_ix, node_type, src_ppr, tgt_ppr = self.get_ppr_vals(
            batch, pair_adj, ppr_matrix)

        cn_filt_cond = (src_ppr >= self.thresh_cn) & (tgt_ppr
                                                      >= self.thresh_cn)
        onehop_filt_cond = (src_ppr >= self.thresh_1hop) & (
            tgt_ppr >= self.thresh_1hop)

        if self.mask != "cn":
            filt_cond = torch.where(node_type == 1, onehop_filt_cond,
                                    cn_filt_cond)
        else:
            filt_cond = torch.where(node_type == 0, onehop_filt_cond,
                                    cn_filt_cond)

        pair_ix, node_type = pair_ix[:, filt_cond], node_type[filt_cond]
        src_ppr, tgt_ppr = src_ppr[filt_cond], tgt_ppr[filt_cond]

        # >1-Hop mask is gotten separately
        if self.mask == "all":
            non1hop_ix, non1hop_sppr, non1hop_tppr = self.get_non_1hop_ppr(
                batch, adj, ppr_matrix)

        # Dropout (randomly drop attended nodes before attention)
        if self.training and self.trans_drop > 0:
            pair_ix, src_ppr, tgt_ppr, node_type = self.drop_pairwise(
                pair_ix, src_ppr, tgt_ppr, node_type)
            if self.mask == "all":
                non1hop_ix, non1hop_sppr, non1hop_tppr, _ = self.drop_pairwise(
                    non1hop_ix, non1hop_sppr, non1hop_tppr)

        if self.mask == "cn":
            return (pair_ix, src_ppr, tgt_ppr), None, None

        # Separate out CN and 1-Hop
        cn_ind = node_type == 2
        cn_info = (pair_ix[:, cn_ind], src_ppr[cn_ind], tgt_ppr[cn_ind])

        one_hop_ind = node_type == 1
        onehop_info = (pair_ix[:, one_hop_ind], src_ppr[one_hop_ind],
                       tgt_ppr[one_hop_ind])

        if self.mask == "1-hop":
            return cn_info, onehop_info, None
        return cn_info, onehop_info, (non1hop_ix, non1hop_sppr,
                                      non1hop_tppr)

    def get_ppr_vals(
            self, batch: Tensor, pair_diff_adj: Tensor,
            ppr_matrix: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        r"""Get the src and tgt ppr vals.

        Returns the: link the node belongs to, type of node
        (e.g., CN), PPR relative to src, PPR relative to tgt.

        Args:
            batch (Tensor): The batch vector.
                Denotes which node pairs to predict.
            pair_diff_adj (SparseTensor): Combination of rows in
                adjacency for src and tgt nodes (e.g., X1 + X2)
            ppr_matrix (Tensor): PPR matrix
        """
        # Additional terms for also choosing scores when ppr=0
        # Multiplication removes any values for nodes not in batch
        # Addition then adds offset to ensure we select when ppr=0
        # All selected scores are scaled/offset by the node type value;
        # the offset is undone at the end of this method.
        src_ppr_adj = torch.index_select(
            ppr_matrix, 0, batch[0]) * pair_diff_adj + pair_diff_adj
        tgt_ppr_adj = torch.index_select(
            ppr_matrix, 0, batch[1]) * pair_diff_adj + pair_diff_adj

        # Can now convert ppr scores to dense
        ppr_ix = src_ppr_adj.coalesce().indices()
        src_ppr = src_ppr_adj.coalesce().values()
        tgt_ppr = tgt_ppr_adj.coalesce().values()

        # TODO: Needed due to a bug in recent torch versions
        # see here for more - https://github.com/pytorch/pytorch/issues/114529
        # note that if one is 0 so is the other
        zero_vals = (src_ppr != 0)
        src_ppr = src_ppr[zero_vals]
        tgt_ppr = tgt_ppr[tgt_ppr != 0]
        ppr_ix = ppr_ix[:, zero_vals]

        pair_diff_adj = pair_diff_adj.coalesce().values()
        node_type = pair_diff_adj[src_ppr != 0]

        # Undo the offset: stored value is (ppr + 1) * node_type
        src_ppr = (src_ppr - node_type) / node_type
        tgt_ppr = (tgt_ppr - node_type) / node_type

        return ppr_ix, node_type, src_ppr, tgt_ppr

    def drop_pairwise(
        self,
        pair_ix: Tensor,
        src_ppr: Optional[Tensor] = None,
        tgt_ppr: Optional[Tensor] = None,
        node_indicator: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        r"""Perform dropout on pairwise information
        by randomly dropping a percentage of nodes.

        Done before performing attention for efficiency. Keeps
        ``ceil(num_entries * (1 - transformer_dropout))`` random entries.

        Args:
            pair_ix (Tensor): Link node belongs to
            src_ppr (Tensor, optional): PPR relative to src
                (default: :obj:`None`)
            tgt_ppr (Tensor, optional): PPR relative to tgt
                (default: :obj:`None`)
            node_indicator (Tensor, optional): Type of node (e.g., CN)
                (default: :obj:`None`)
        """
        num_indices = math.ceil(pair_ix.size(1) * (1 - self.trans_drop))
        indices = torch.randperm(pair_ix.size(1))[:num_indices]
        pair_ix = pair_ix[:, indices]

        if src_ppr is not None:
            src_ppr = src_ppr[indices]
        if tgt_ppr is not None:
            tgt_ppr = tgt_ppr[indices]
        if node_indicator is not None:
            node_indicator = node_indicator[indices]

        return pair_ix, src_ppr, tgt_ppr, node_indicator

    def get_structure_cnts(
        self,
        batch: Tensor,
        cn_info: Tuple[Tensor, Tensor, Tensor],
        onehop_info: Optional[Tuple[Tensor, Tensor, Tensor]],
        non1hop_info: Optional[Tuple[Tensor, Tensor, Tensor]],
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Counts for CNs, 1-Hop, and >1-Hop that satisfy PPR threshold.

        Also include total # of neighbors. Every returned value is a
        ``[batch_size, 1]`` tensor so it can be concatenated to features;
        node types that are disabled by the thresholds count as zero.

        Args:
            batch (Tensor): The batch vector.
                Denotes which node pairs to predict.
            cn_info (tuple): Information of CN nodes
                Contains (ID of node, src ppr, tgt ppr)
            onehop_info (tuple, optional): Information of 1-Hop nodes.
                Contains (ID of node, src ppr, tgt ppr)
            non1hop_info (tuple, optional): Information of >1-Hop nodes.
                Contains (ID of node, src ppr, tgt ppr)
        """
        num_cns = self.get_num_ppr_thresh(batch, cn_info[0], cn_info[1],
                                          cn_info[2], self.thresh_cn)

        if onehop_info is None:
            # "cn" mode: no 1-Hop nodes were gathered, so they count as 0
            # and the neighbor total consists of the CNs alone.
            num_1hop = torch.zeros_like(num_cns)
            num_neighbors = num_cns
        else:
            num_1hop = self.get_num_ppr_thresh(batch, onehop_info[0],
                                               onehop_info[1], onehop_info[2],
                                               self.thresh_1hop)
            # TOTAL num of 1-hop neighbors union
            num_ppr_ones = self.get_num_ppr_thresh(batch, onehop_info[0],
                                                   onehop_info[1],
                                                   onehop_info[2], thresh=0)
            num_neighbors = num_cns + num_ppr_ones

        # Process for >1-hop is different which is why we use get_count below
        if non1hop_info is None:
            num_non1hop = torch.zeros_like(num_cns)
        else:
            num_non1hop = self.get_count(non1hop_info[0], batch)

        return num_cns, num_1hop, num_non1hop, num_neighbors

    def get_num_ppr_thresh(self, batch: Tensor, node_mask: Tensor,
                           src_ppr: Tensor, tgt_ppr: Tensor,
                           thresh: float) -> Tensor:
        """Get # of nodes `v` where `ppr(a, v) >= eta` & `ppr(b, v) >= eta`.

        Returns a ``[batch_size, 1]`` float tensor of per-link counts.

        Args:
            batch (Tensor): The batch vector.
                Denotes which node pairs to predict.
            node_mask (Tensor): IDs of nodes (row 0 = link position)
            src_ppr (Tensor): PPR relative to src node
            tgt_ppr (Tensor): PPR relative to tgt node
            thresh (float): PPR threshold for nodes (`eta`)
        """
        ppr_above_thresh = (src_ppr >= thresh) & (tgt_ppr >= thresh)
        num_ppr = scatter(ppr_above_thresh.float(), node_mask[0].long(),
                          dim=0, dim_size=batch.size(1), reduce="sum")

        return num_ppr.unsqueeze(-1)

    def get_count(
        self,
        node_mask: Tensor,
        batch: Tensor,
    ) -> Tensor:
        """# of nodes for each sample in batch.

        The nodes have already been filtered by PPR beforehand. Returns a
        ``[batch_size, 1]`` float tensor.

        Args:
            node_mask (Tensor): IDs of nodes (row 0 = link position)
            batch (Tensor): The batch vector.
                Denotes which node pairs to predict.
        """
        weight = torch.ones(node_mask.size(1), device=node_mask.device)
        num_nodes = scatter(weight, node_mask[0].long(), dim=0,
                            dim_size=batch.size(1), reduce="sum")

        return num_nodes.unsqueeze(-1)

    def get_non_1hop_ppr(self, batch: Tensor, adj: Tensor,
                         ppr_matrix: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        r"""Get PPR scores for non-1hop nodes.

        Returns ``(ix, src_ppr, tgt_ppr)`` for every node that is neither a
        CN nor a 1-Hop neighbor of its link and whose PPR relative to both
        endpoints is at least the >1-Hop threshold.

        Args:
            batch (Tensor): Links in batch
            adj (Tensor): Adjacency matrix
            ppr_matrix (Tensor): Sparse PPR matrix
        """
        # NOTE: Use original adj (one pass in forward() removes links in batch)
        # Done since removing them converts src/tgt nodes to >1-hop nodes.
        # Therefore removing CN and 1-hop will also remove the batch links.

        # During training we add back in the links in the batch
        # (they're removed from adjacency before being passed to model)
        # Done since otherwise they will be mistakenly seen as >1-Hop nodes
        # Instead they're 1-Hop, and get ignored accordingly
        # Ignored during eval since we know the links aren't in the adj
        adj2 = adj
        if self.training:
            n = adj.size(0)
            batch_flip = torch.cat(
                (batch, torch.flip(batch, (0, )).to(batch.device)), dim=-1)
            batch_ones = torch.ones_like(batch_flip[0], device=batch.device)
            adj_edges = torch.sparse_coo_tensor(batch_flip, batch_ones, [n, n],
                                                device=batch.device)
            adj2 = (adj + adj_edges).coalesce().bool().int()

        src_adj = torch.index_select(adj2, 0, batch[0])
        tgt_adj = torch.index_select(adj2, 0, batch[1])

        src_ppr = torch.index_select(ppr_matrix, 0, batch[0])
        tgt_ppr = torch.index_select(ppr_matrix, 0, batch[1])

        # Remove CN scores. The CN mask is computed once, before either
        # adjacency row is modified; recomputing it after zeroing `src_adj`
        # would be empty and leave the CN entries in `tgt_adj`.
        cn_mask = src_adj * tgt_adj
        src_ppr = src_ppr - src_ppr * cn_mask
        tgt_ppr = tgt_ppr - tgt_ppr * cn_mask
        # Also need to remove CN entries in Adj
        # Otherwise they leak into next computation
        src_adj = src_adj - src_adj * cn_mask
        tgt_adj = tgt_adj - tgt_adj * cn_mask

        # Remove 1-Hop scores
        onehop_mask = src_adj + tgt_adj
        src_ppr = src_ppr - src_ppr * onehop_mask
        tgt_ppr = tgt_ppr - tgt_ppr * onehop_mask

        # Make sure we include both when we convert to dense so indices align
        # Do so by adding 1 to each based on the other
        src_ppr_add = src_ppr + torch.sign(tgt_ppr)
        tgt_ppr_add = tgt_ppr + torch.sign(src_ppr)

        src_ix = src_ppr_add.coalesce().indices()
        src_vals = src_ppr_add.coalesce().values()
        tgt_vals = tgt_ppr_add.coalesce().values()

        # Now we can remove value which is just 1
        # Technically creates -1 scores for ppr scores that were 0
        # Doesn't matter as they'll be filtered out by condition later
        src_vals = src_vals - 1
        tgt_vals = tgt_vals - 1

        ppr_condition = (src_vals >= self.thresh_non1hop) & (
            tgt_vals >= self.thresh_non1hop)

        return (src_ix[:, ppr_condition], src_vals[ppr_condition],
                tgt_vals[ppr_condition])

    def calc_sparse_ppr(self, edge_index: Tensor, num_nodes: int,
                        alpha: float = 0.15, eps: float = 5e-5) -> Tensor:
        r"""Calculate the PPR of the graph in sparse format.

        Args:
            edge_index: The edge indices
            num_nodes: Number of nodes
            alpha (float, optional): The alpha value of the PageRank algorithm.
                (default: :obj:`0.15`)
            eps (float, optional): Threshold for stopping the PPR calculation
                (default: :obj:`5e-5`)
        """
        ei, ei_w = get_ppr(edge_index.cpu(), alpha=alpha, eps=eps,
                           num_nodes=num_nodes)
        ppr_matrix = torch.sparse_coo_tensor(ei, ei_w, [num_nodes, num_nodes])

        return ppr_matrix
+
614
+
615
class LPAttLayer(MessagePassing):
    r"""Attention Layer for pairwise interaction module.

    Each link (a row of ``edge_feats`` holding the concatenated endpoint
    representations) attends to the nodes listed for it in ``edge_index``
    (row 0: link position, row 1: node id). Messages are node features
    concatenated with the relative positional encodings.

    Args:
        in_channels (int): Size of input dimension per link endpoint
            (the incoming link features are split into two halves of this
            size)
        out_channels (int): Size of output dimension (per head)
        node_dim (int): Dimension of nodes being aggregated. If
            :obj:`None`, :obj:`in_channels` is used.
        num_heads (int): Number of heads to use in MHA
        dropout (float): Dropout applied to the layer's output
        concat (bool, optional): Whether to concat attention
            heads. Otherwise average (default: :obj:`True`)
    """
    _alpha: OptTensor

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        node_dim: int,
        num_heads: int,
        dropout: float,
        concat: bool = True,
        **kwargs,
    ):
        super().__init__(node_dim=0, flow="target_to_source", **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = num_heads
        self.concat = concat
        self.dropout = dropout
        self.negative_slope = 0.2  # LeakyRelu

        # Messages are cat(node features, positional encodings), each of
        # size `node_dim`, hence the factor of 2.
        out_dim = 2
        if node_dim is None:
            node_dim = in_channels * out_dim
        else:
            node_dim = node_dim * out_dim

        self.lin_l = Linear(in_channels, self.heads * out_channels,
                            weight_initializer='glorot')
        self.lin_r = Linear(node_dim, self.heads * out_channels,
                            weight_initializer='glorot')

        self.att = Parameter(Tensor(1, self.heads, out_channels))

        # Width of the layer output: heads are concatenated or averaged.
        final_channels = self.heads * out_channels if concat else out_channels
        self.bias = Parameter(Tensor(final_channels))

        self._alpha = None

        # Normalizes the full output width, which is
        # `heads * out_channels` when heads are concatenated.
        self.post_att_norm = nn.LayerNorm(final_channels)

        self.reset_parameters()

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()
        self.post_att_norm.reset_parameters()
        glorot(self.att)
        zeros(self.bias)

    def forward(
        self,
        edge_index: Tensor,
        edge_feats: Tensor,
        node_feats: Tensor,
        ppr_rpes: Tensor,
    ) -> Tensor:
        """Runs the forward pass of the module.

        Args:
            edge_index (Tensor): The edge indices.
            edge_feats (Tensor): Concatenated representations
                of src and target nodes for each link
            node_feats (Tensor): Representations for individual
                nodes
            ppr_rpes (Tensor): Relative PEs for each node
        """
        out = self.propagate(edge_index, x=(edge_feats, node_feats),
                             ppr_rpes=ppr_rpes, size=None)

        alpha = self._alpha
        assert alpha is not None
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out = out + self.bias

        out = self.post_att_norm(out)
        out = F.dropout(out, p=self.dropout, training=self.training)

        return out

    def message(self, x_i: Tensor, x_j: Tensor, ppr_rpes: Tensor,
                index: Tensor, ptr: Tensor, size_i: Optional[int]) -> Tensor:
        H, C = self.heads, self.out_channels

        # Attended node representation enriched with its positional encoding
        x_j = torch.cat((x_j, ppr_rpes), dim=-1)
        x_j = self.lin_r(x_j).view(-1, H, C)

        # e=(a, b) attending to v
        e1, e2 = x_i.chunk(2, dim=-1)
        e1 = self.lin_l(e1).view(-1, H, C)
        e2 = self.lin_l(e2).view(-1, H, C)
        x = x_j * (e1 + e2)

        x = F.leaky_relu(x, self.negative_slope)
        alpha = (x * self.att).sum(dim=-1)

        # Normalize over all nodes attended by the same link
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha

        return x_j * alpha.unsqueeze(-1)
+
743
+
744
class MLP(nn.Module):
    r"""L Layer MLP.

    ``Linear -> norm -> ReLU -> dropout`` for every hidden layer, followed by
    a final ``Linear``. A trailing output dimension of size 1 is squeezed.

    Args:
        in_channels (int): Size of input dimension
        hid_channels (int): Size of hidden dimension
        out_channels (int): Size of output dimension
        num_layers (int, optional): Number of linear layers
            (default: :obj:`2`)
        drop (float, optional): Dropout after each hidden layer
            (default: :obj:`0`)
        norm (str, optional): :obj:`"batch"`, :obj:`"layer"`, or
            :obj:`None` for no normalization. Each hidden layer gets its own
            normalization module. (default: :obj:`"layer"`)
    """
    def __init__(self, in_channels: int, hid_channels: int, out_channels: int,
                 num_layers: int = 2, drop: float = 0,
                 norm: Optional[str] = "layer"):
        super().__init__()
        self.dropout = drop

        def make_norm() -> Optional[nn.Module]:
            if norm == "batch":
                return nn.BatchNorm1d(hid_channels)
            if norm == "layer":
                return nn.LayerNorm(hid_channels)
            return None

        # Norm of the first hidden layer; keeps the name `norm` so existing
        # state dicts (which only ever had this one norm) still load.
        self.norm = make_norm()
        # Separate norms for any further hidden layers, so layers do not
        # share affine parameters or BatchNorm running statistics.
        self.extra_norms = nn.ModuleList()
        if self.norm is not None:
            for _ in range(num_layers - 2):
                self.extra_norms.append(make_norm())

        self.linears = torch.nn.ModuleList()

        if num_layers == 1:
            self.linears.append(nn.Linear(in_channels, out_channels))
        else:
            self.linears.append(nn.Linear(in_channels, hid_channels))
            for _ in range(num_layers - 2):
                self.linears.append(nn.Linear(hid_channels, hid_channels))
            self.linears.append(nn.Linear(hid_channels, out_channels))

    def _hidden_norms(self) -> list:
        """Norm modules in hidden-layer order (empty when norm is None)."""
        if self.norm is None:
            return []
        return [self.norm, *self.extra_norms]

    def reset_parameters(self):
        for lin in self.linears:
            lin.reset_parameters()
        for norm in self._hidden_norms():
            norm.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        norms = self._hidden_norms()
        for i, lin in enumerate(self.linears[:-1]):
            x = lin(x)
            if norms:
                x = norms[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.linears[-1](x)

        return x.squeeze(-1)