pyg-nightly 2.7.0.dev20250825__py3-none-any.whl → 2.7.0.dev20250827__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250825.dist-info → pyg_nightly-2.7.0.dev20250827.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250825.dist-info → pyg_nightly-2.7.0.dev20250827.dist-info}/RECORD +13 -12
- torch_geometric/__init__.py +3 -2
- torch_geometric/_onnx.py +214 -0
- torch_geometric/loader/link_neighbor_loader.py +1 -0
- torch_geometric/nn/models/__init__.py +3 -0
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +258 -3
- torch_geometric/sampler/neighbor_sampler.py +283 -13
- torch_geometric/sampler/utils.py +48 -5
- {pyg_nightly-2.7.0.dev20250825.dist-info → pyg_nightly-2.7.0.dev20250827.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250825.dist-info → pyg_nightly-2.7.0.dev20250827.dist-info}/licenses/LICENSE +0 -0
torch_geometric/sampler/__init__.py
CHANGED
@@ -3,7 +3,7 @@ r"""Graph sampler package."""
|
|
3
3
|
from .base import (BaseSampler, NodeSamplerInput, EdgeSamplerInput,
|
4
4
|
SamplerOutput, HeteroSamplerOutput, NegativeSampling,
|
5
5
|
NumNeighbors)
|
6
|
-
from .neighbor_sampler import NeighborSampler
|
6
|
+
from .neighbor_sampler import NeighborSampler, BidirectionalNeighborSampler
|
7
7
|
from .hgt_sampler import HGTSampler
|
8
8
|
|
9
9
|
__all__ = classes = [
|
@@ -15,5 +15,6 @@ __all__ = classes = [
|
|
15
15
|
'NumNeighbors',
|
16
16
|
'NegativeSampling',
|
17
17
|
'NeighborSampler',
|
18
|
+
'BidirectionalNeighborSampler',
|
18
19
|
'HGTSampler',
|
19
20
|
]
|
torch_geometric/sampler/base.py
CHANGED
@@ -3,7 +3,7 @@ import math
|
|
3
3
|
import warnings
|
4
4
|
from abc import ABC, abstractmethod
|
5
5
|
from collections import defaultdict
|
6
|
-
from dataclasses import dataclass
|
6
|
+
from dataclasses import dataclass, field
|
7
7
|
from enum import Enum
|
8
8
|
from typing import Any, Dict, List, Literal, Optional, Union
|
9
9
|
|
@@ -12,8 +12,10 @@ from torch import Tensor
|
|
12
12
|
|
13
13
|
from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData
|
14
14
|
from torch_geometric.sampler.utils import (
|
15
|
+
global_to_local_node_idx,
|
15
16
|
local_to_global_node_idx,
|
16
17
|
to_bidirectional,
|
18
|
+
unique_unsorted,
|
17
19
|
)
|
18
20
|
from torch_geometric.typing import EdgeType, EdgeTypeStr, NodeType, OptTensor
|
19
21
|
from torch_geometric.utils.mixin import CastMixin
|
@@ -209,6 +211,7 @@ class SamplerOutput(CastMixin):
|
|
209
211
|
# TODO(manan): refine this further; it does not currently define a proper
|
210
212
|
# API for the expected output of a sampler.
|
211
213
|
metadata: Optional[Any] = None
|
214
|
+
_seed_node: OptTensor = field(repr=False, default=None)
|
212
215
|
|
213
216
|
@property
|
214
217
|
def global_row(self) -> Tensor:
|
@@ -220,8 +223,17 @@ class SamplerOutput(CastMixin):
|
|
220
223
|
|
221
224
|
@property
def seed_node(self) -> Optional[Tensor]:
    r"""Global ids of the seed node associated with each entry of
    :obj:`node`.

    If no value was assigned via the setter, it is computed lazily on first
    access via :obj:`local_to_global_node_idx(node, batch)` and cached in
    :obj:`_seed_node`. Returns :obj:`None` when :obj:`batch` is not set and
    nothing was assigned manually.
    """
    # Seed nodes can be set manually (see the setter), e.g. when the seed
    # nodes are not contained in the sampled nodes.
    if self._seed_node is None and self.batch is not None:
        self._seed_node = local_to_global_node_idx(self.node, self.batch)
    return self._seed_node

@seed_node.setter
def seed_node(self, value: Tensor):
    r"""Manually assigns the seed node of every entry in :obj:`node`.

    Raises:
        ValueError: If :obj:`value` does not contain exactly one entry per
            node in :obj:`node`.
    """
    # Explicit check rather than ``assert``: assertions are stripped when
    # Python runs with ``-O``, which would silently skip validation.
    if len(value) != len(self.node):
        raise ValueError(
            f"Expected one seed node per sampled node, got "
            f"{len(value)} seed nodes for {len(self.node)} nodes")
    self._seed_node = value
|
225
237
|
|
226
238
|
@property
|
227
239
|
def global_orig_row(self) -> Tensor:
|
@@ -263,6 +275,230 @@ class SamplerOutput(CastMixin):
|
|
263
275
|
|
264
276
|
return out
|
265
277
|
|
278
|
+
@classmethod
def collate(cls, outputs: List['SamplerOutput'],
            replace: bool = True) -> 'SamplerOutput':
    r"""Collate a list of :class:`~torch_geometric.sampler.SamplerOutput`
    objects into a single :class:`~torch_geometric.sampler.SamplerOutput`
    object. Requires that they all have the same fields.

    Args:
        outputs (List[SamplerOutput]): The outputs to collate. They are
            merged left to right via :meth:`merge_with`.
        replace (bool, optional): Forwarded to :meth:`merge_with`. If
            :obj:`True`, nodes and edges shared between outputs are
            de-duplicated; otherwise outputs are concatenated.
            (default: :obj:`True`)

    Raises:
        ValueError: If :obj:`outputs` is empty, or if some output differs
            from the first one in which optional fields are set.
    """
    if len(outputs) == 0:
        raise ValueError("Cannot collate an empty list of SamplerOutputs")

    # Optional fields whose presence (set vs. None) must agree across all
    # outputs so that they can be merged consistently.
    optional_fields = ('edge', 'orig_row', 'orig_col', 'batch',
                       'num_sampled_nodes', 'num_sampled_edges')

    def _presence(output: 'SamplerOutput') -> tuple:
        return tuple(
            getattr(output, name) is not None for name in optional_fields)

    # Explicit comparison instead of ``assert``: assertions are stripped
    # under ``python -O``, which would silently skip this validation.
    reference = _presence(outputs[0])
    for i, sample_output in enumerate(outputs):
        if _presence(sample_output) != reference:
            raise ValueError(f"Output {i+1} has a different field than the "
                             f"first output")

    out = outputs[0]
    for other in outputs[1:]:
        out = out.merge_with(other, replace=replace)
    return out
|
312
|
+
|
313
|
+
def merge_with(self, other: 'SamplerOutput',
               replace: bool = True) -> 'SamplerOutput':
    r"""Merges two :class:`~torch_geometric.sampler.SamplerOutput` objects.

    If :obj:`replace` is :obj:`False`, the outputs are concatenated: the
    nodes of :obj:`other` are appended after those of :obj:`self`, and the
    node-relative indices of :obj:`other` (:obj:`row`, :obj:`col`,
    :obj:`batch`, :obj:`orig_row`, :obj:`orig_col`) are shifted by
    :obj:`len(self.node)`.

    If :obj:`replace` is :obj:`True`, nodes and edges contained in both
    outputs are de-duplicated, and :obj:`self`'s nodes and edges take
    precedence. A node is identified by its id, paired with its seed node
    when :obj:`batch` is set (so disjoint subgraphs stay separate). An edge
    is identified by the identities of its two endpoints, plus its edge id
    if present. The per-hop counts of :obj:`other` are recounted to exclude
    nodes/edges already contained in :obj:`self`.

    In both modes :obj:`metadata` becomes
    :obj:`[self.metadata, other.metadata]` and seed nodes (including
    manually assigned ones) are passed on to the merged output.

    Args:
        other (SamplerOutput): The output to merge with :obj:`self`.
        replace (bool, optional): Whether to de-duplicate shared nodes and
            edges. (default: :obj:`True`)
    """
    def _count_new_per_hop(old_uid: Tensor, new_uid: Tensor,
                           counts: List[int]) -> List[int]:
        # Splits ``new_uid`` into consecutive hops of the given sizes and
        # counts, per hop, the uids that are not already in ``old_uid``.
        # NOTE: assumes no uid is sampled twice within the same output.
        num_old = old_uid.shape[0]
        new_counts, start = [], 0
        for count in counts:
            chunk = new_uid[start:start + count]
            num_union = torch.cat([old_uid, chunk]).unique(
                dim=0, sorted=False).shape[0]
            new_counts.append(num_union - num_old)
            start += count
        return new_counts

    has_batch = self.batch is not None and other.batch is not None

    if not replace:
        offset = len(self.node)

        def _cat(a: Optional[Tensor], b: Optional[Tensor],
                 shift: bool = False) -> Optional[Tensor]:
            # Concatenates a field set in both outputs; with ``shift``,
            # ``b`` holds node indices that must point past ``self.node``.
            if a is None or b is None:
                return None
            return torch.cat([a, offset + b if shift else b], dim=0)

        def _join(a: Optional[list], b: Optional[list]) -> Optional[list]:
            # Per-hop count lists are simply appended one after another.
            return a + b if a is not None and b is not None else None

        # Carry seed nodes over (they may have been assigned manually and
        # would otherwise be lost); they stay aligned with the node concat.
        self_seed, other_seed = self.seed_node, other.seed_node
        merged_seed = (torch.cat([self_seed, other_seed], dim=0)
                       if self_seed is not None and other_seed is not None
                       else None)

        return SamplerOutput(
            node=torch.cat([self.node, other.node], dim=0),
            row=torch.cat([self.row, offset + other.row], dim=0),
            col=torch.cat([self.col, offset + other.col], dim=0),
            edge=_cat(self.edge, other.edge),
            batch=_cat(self.batch, other.batch, shift=True),
            num_sampled_nodes=_join(self.num_sampled_nodes,
                                    other.num_sampled_nodes),
            num_sampled_edges=_join(self.num_sampled_edges,
                                    other.num_sampled_edges),
            orig_row=_cat(self.orig_row, other.orig_row, shift=True),
            orig_col=_cat(self.orig_col, other.orig_col, shift=True),
            metadata=[self.metadata, other.metadata],
            _seed_node=merged_seed,
        )

    # NODES ###################################################################
    # A node uid is its global id, plus its seed node when ``batch`` tracks
    # disjoint subgraph samplings (a node may then occur once per subgraph).
    old_node_uid, new_node_uid = [self.node], [other.node]
    if has_batch:
        old_node_uid.append(self.seed_node)
        new_node_uid.append(other.seed_node)
    # NOTE: if any new node fields are added, they need to be merged here.
    old_node_uid = torch.stack(old_node_uid, dim=1)  # num_nodes x k
    new_node_uid = torch.stack(new_node_uid, dim=1)  # num_nodes x k

    merged_node_uid = unique_unsorted(
        torch.cat([old_node_uid, new_node_uid], dim=0))

    merged_num_sampled_nodes = None
    if (self.num_sampled_nodes is not None
            and other.num_sampled_nodes is not None):
        merged_num_sampled_nodes = list(self.num_sampled_nodes)
        merged_num_sampled_nodes += _count_new_per_hop(
            old_node_uid, new_node_uid, other.num_sampled_nodes)

    merged_nodes = merged_node_uid[:, 0]
    merged_batch = None
    merged_seed = None
    if has_batch:
        merged_seed = merged_node_uid[:, 1]
        # Restore batch indices relative to ``node``: a node's batch entry
        # is the position of its seed's own ``(seed, seed)`` uid row.
        seed_uid = merged_seed.unsqueeze(-1).expand(-1, 2)  # num_nodes x 2
        merged_batch = global_to_local_node_idx(merged_node_uid, seed_uid)

    # EDGES ###################################################################
    is_bidirectional = (self.orig_row is not None
                        and self.orig_col is not None
                        and other.orig_row is not None
                        and other.orig_col is not None)
    if is_bidirectional:
        old_row, old_col = self.orig_row, self.orig_col
        new_row, new_col = other.orig_row, other.orig_col
    else:
        old_row, old_col = self.row, self.col
        new_row, new_col = other.row, other.col

    # Edge uids build off node uids: map the node-relative row/col indices
    # to the uids of the corresponding endpoints.
    old_row_uid = local_to_global_node_idx(old_node_uid, old_row)
    old_col_uid = local_to_global_node_idx(old_node_uid, old_col)
    new_row_uid = local_to_global_node_idx(new_node_uid, new_row)
    new_col_uid = local_to_global_node_idx(new_node_uid, new_col)

    # Column layout of an edge uid: [row uid | col uid | edge id (opt.)].
    col_start = old_row_uid.shape[1]
    edge_start = old_row_uid.shape[1] + old_col_uid.shape[1]

    old_edge_uid = [old_row_uid, old_col_uid]
    new_edge_uid = [new_row_uid, new_col_uid]

    has_edge = self.edge is not None and other.edge is not None
    if has_edge:
        if is_bidirectional:
            # ``edge`` is aligned with the bidirectional row/col (which
            # duplicate edge ids); look up the id of each original edge.
            old_edge_idx = global_to_local_node_idx(
                torch.stack([self.row, self.col], dim=1),
                torch.stack([self.orig_row, self.orig_col], dim=1))
            new_edge_idx = global_to_local_node_idx(
                torch.stack([other.row, other.col], dim=1),
                torch.stack([other.orig_row, other.orig_col], dim=1))
            old_edge = self.edge[old_edge_idx]
            new_edge = other.edge[new_edge_idx]
        else:
            old_edge, new_edge = self.edge, other.edge
        old_edge_uid.append(old_edge.unsqueeze(-1))
        new_edge_uid.append(new_edge.unsqueeze(-1))

    old_edge_uid = torch.cat(old_edge_uid, dim=1)
    new_edge_uid = torch.cat(new_edge_uid, dim=1)

    merged_edge_uid = unique_unsorted(
        torch.cat([old_edge_uid, new_edge_uid], dim=0))

    merged_num_sampled_edges = None
    if (self.num_sampled_edges is not None
            and other.num_sampled_edges is not None):
        merged_num_sampled_edges = list(self.num_sampled_edges)
        merged_num_sampled_edges += _count_new_per_hop(
            old_edge_uid, new_edge_uid, other.num_sampled_edges)

    # Restore row and col to indices relative to the merged ``node`` field.
    merged_row = global_to_local_node_idx(
        merged_node_uid, merged_edge_uid[:, :col_start])
    merged_col = global_to_local_node_idx(
        merged_node_uid, merged_edge_uid[:, col_start:edge_start])
    # Integer column indexing always yields a 1-D tensor; the previous
    # ``[:, edge_start:].squeeze()`` collapsed to a 0-d tensor whenever
    # exactly one edge remained.
    merged_edge = merged_edge_uid[:, edge_start] if has_edge else None

    out = SamplerOutput(
        node=merged_nodes,
        row=merged_row,
        col=merged_col,
        edge=merged_edge,
        batch=merged_batch,
        num_sampled_nodes=merged_num_sampled_nodes,
        num_sampled_edges=merged_num_sampled_edges,
        metadata=[self.metadata, other.metadata],
        _seed_node=merged_seed,
    )
    # Restore orig_row and orig_col if they existed before merging.
    # NOTE(review): whether ``to_bidirectional`` keeps ``_seed_node`` is
    # not visible here; if dropped, it is recomputed lazily from ``batch``.
    if is_bidirectional:
        out = out.to_bidirectional(keep_orig_edges=True)
    return out
|
501
|
+
|
266
502
|
|
267
503
|
@dataclass
|
268
504
|
class HeteroSamplerOutput(CastMixin):
|
@@ -439,6 +675,25 @@ class HeteroSamplerOutput(CastMixin):
|
|
439
675
|
|
440
676
|
return out
|
441
677
|
|
678
|
+
@classmethod
def collate(cls, outputs: List['HeteroSamplerOutput'],
            replace: bool = True) -> 'HeteroSamplerOutput':
    r"""Collate a list of
    :class:`~torch_geometric.sampler.HeteroSamplerOutput` objects into a
    single :class:`~torch_geometric.sampler.HeteroSamplerOutput` object.
    Requires that they all have the same fields.

    .. note::
        Not implemented yet: every call raises
        :class:`NotImplementedError`.
    """
    # TODO(zaristei)
    raise NotImplementedError
|
688
|
+
|
689
|
+
def merge_with(self, other: 'HeteroSamplerOutput',
               replace: bool = True) -> 'HeteroSamplerOutput':
    r"""Merge :obj:`other` into a new heterogeneous sampler output.

    Intended semantics: when :obj:`replace` is :obj:`True`, the nodes and
    edges of :obj:`self` win over those of :obj:`other`. Not implemented
    yet; every call raises :class:`NotImplementedError`.
    """
    # TODO(zaristei): implement heterogeneous merging.
    raise NotImplementedError()
|
696
|
+
|
442
697
|
|
443
698
|
@dataclass(frozen=True)
|
444
699
|
class NumNeighbors:
|