pyg-nightly 2.7.0.dev20250826__py3-none-any.whl → 2.7.0.dev20250828__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package released to one of the supported registries. It is provided for informational purposes only and reflects each version's contents as they appear in its respective public registry.
@@ -3,7 +3,7 @@ r"""Graph sampler package."""
3
3
  from .base import (BaseSampler, NodeSamplerInput, EdgeSamplerInput,
4
4
  SamplerOutput, HeteroSamplerOutput, NegativeSampling,
5
5
  NumNeighbors)
6
- from .neighbor_sampler import NeighborSampler
6
+ from .neighbor_sampler import NeighborSampler, BidirectionalNeighborSampler
7
7
  from .hgt_sampler import HGTSampler
8
8
 
9
9
  __all__ = classes = [
@@ -15,5 +15,6 @@ __all__ = classes = [
15
15
  'NumNeighbors',
16
16
  'NegativeSampling',
17
17
  'NeighborSampler',
18
+ 'BidirectionalNeighborSampler',
18
19
  'HGTSampler',
19
20
  ]
@@ -3,7 +3,7 @@ import math
3
3
  import warnings
4
4
  from abc import ABC, abstractmethod
5
5
  from collections import defaultdict
6
- from dataclasses import dataclass
6
+ from dataclasses import dataclass, field
7
7
  from enum import Enum
8
8
  from typing import Any, Dict, List, Literal, Optional, Union
9
9
 
@@ -12,8 +12,10 @@ from torch import Tensor
12
12
 
13
13
  from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData
14
14
  from torch_geometric.sampler.utils import (
15
+ global_to_local_node_idx,
15
16
  local_to_global_node_idx,
16
17
  to_bidirectional,
18
+ unique_unsorted,
17
19
  )
18
20
  from torch_geometric.typing import EdgeType, EdgeTypeStr, NodeType, OptTensor
19
21
  from torch_geometric.utils.mixin import CastMixin
@@ -209,6 +211,7 @@ class SamplerOutput(CastMixin):
209
211
  # TODO(manan): refine this further; it does not currently define a proper
210
212
  # API for the expected output of a sampler.
211
213
  metadata: Optional[Any] = None
214
+ _seed_node: OptTensor = field(repr=False, default=None)
212
215
 
213
216
  @property
214
217
  def global_row(self) -> Tensor:
@@ -220,8 +223,17 @@ class SamplerOutput(CastMixin):
220
223
 
221
224
  @property
222
225
  def seed_node(self) -> Tensor:
223
- return local_to_global_node_idx(
224
- self.node, self.batch) if self.batch is not None else None
226
+ # can be set manually if the seed nodes are not contained in the
227
+ # sampled nodes
228
+ if self._seed_node is None:
229
+ self._seed_node = local_to_global_node_idx(
230
+ self.node, self.batch) if self.batch is not None else None
231
+ return self._seed_node
232
+
233
+ @seed_node.setter
234
+ def seed_node(self, value: Tensor):
235
+ assert len(value) == len(self.node)
236
+ self._seed_node = value
225
237
 
226
238
  @property
227
239
  def global_orig_row(self) -> Tensor:
@@ -263,6 +275,230 @@ class SamplerOutput(CastMixin):
263
275
 
264
276
  return out
265
277
 
278
+ @classmethod
279
+ def collate(cls, outputs: List['SamplerOutput'],
280
+ replace: bool = True) -> 'SamplerOutput':
281
+ r"""Collate a list of :class:`~torch_geometric.sampler.SamplerOutput`
282
+ objects into a single :class:`~torch_geometric.sampler.SamplerOutput`
283
+ object. Requires that they all have the same fields.
284
+ """
285
+ if len(outputs) == 0:
286
+ raise ValueError("Cannot collate an empty list of SamplerOutputs")
287
+ out = outputs[0]
288
+ has_edge = out.edge is not None
289
+ has_orig_row = out.orig_row is not None
290
+ has_orig_col = out.orig_col is not None
291
+ has_batch = out.batch is not None
292
+ has_num_sampled_nodes = out.num_sampled_nodes is not None
293
+ has_num_sampled_edges = out.num_sampled_edges is not None
294
+
295
+ try:
296
+ for i, sample_output in enumerate(outputs): # noqa
297
+ assert not has_edge == (sample_output.edge is None)
298
+ assert not has_orig_row == (sample_output.orig_row is None)
299
+ assert not has_orig_col == (sample_output.orig_col is None)
300
+ assert not has_batch == (sample_output.batch is None)
301
+ assert not has_num_sampled_nodes == (
302
+ sample_output.num_sampled_nodes is None)
303
+ assert not has_num_sampled_edges == (
304
+ sample_output.num_sampled_edges is None)
305
+ except AssertionError:
306
+ error_str = f"Output {i+1} has a different field than the first output" # noqa
307
+ raise ValueError(error_str) # noqa
308
+
309
+ for other in outputs[1:]:
310
+ out = out.merge_with(other, replace=replace)
311
+ return out
312
+
313
+ def merge_with(self, other: 'SamplerOutput',
314
+ replace: bool = True) -> 'SamplerOutput':
315
+ """Merges two SamplerOutputs.
316
+ If replace is True, self's nodes and edges take precedence.
317
+ """
318
+ if not replace:
319
+ return SamplerOutput(
320
+ node=torch.cat([self.node, other.node], dim=0),
321
+ row=torch.cat([self.row, len(self.node) + other.row], dim=0),
322
+ col=torch.cat([self.col, len(self.node) + other.col], dim=0),
323
+ edge=torch.cat([self.edge, other.edge], dim=0)
324
+ if self.edge is not None and other.edge is not None else None,
325
+ batch=torch.cat(
326
+ [self.batch, len(self.node) + other.batch], dim=0) if
327
+ self.batch is not None and other.batch is not None else None,
328
+ num_sampled_nodes=self.num_sampled_nodes +
329
+ other.num_sampled_nodes if self.num_sampled_nodes is not None
330
+ and other.num_sampled_nodes is not None else None,
331
+ num_sampled_edges=self.num_sampled_edges +
332
+ other.num_sampled_edges if self.num_sampled_edges is not None
333
+ and other.num_sampled_edges is not None else None,
334
+ orig_row=torch.cat(
335
+ [self.orig_row,
336
+ len(self.node) +
337
+ other.orig_row], dim=0) if self.orig_row is not None
338
+ and other.orig_row is not None else None,
339
+ orig_col=torch.cat(
340
+ [self.orig_col,
341
+ len(self.node) +
342
+ other.orig_col], dim=0) if self.orig_col is not None
343
+ and other.orig_col is not None else None,
344
+ metadata=[self.metadata, other.metadata],
345
+ )
346
+ else:
347
+
348
+ # NODES
349
+ old_nodes, new_nodes = self.node, other.node
350
+ old_node_uid, new_node_uid = [old_nodes], [new_nodes]
351
+
352
+ # batch tracks disjoint subgraph samplings
353
+ if self.batch is not None and other.batch is not None:
354
+ # Transform the batch indices to be global node ids
355
+ old_batch_nodes = self.seed_node
356
+ new_batch_nodes = other.seed_node
357
+ old_node_uid.append(old_batch_nodes)
358
+ new_node_uid.append(new_batch_nodes)
359
+
360
+ # NOTE: if any new node fields are added,
361
+ # they need to be merged here
362
+
363
+ old_node_uid = torch.stack(old_node_uid, dim=1)
364
+ new_node_uid = torch.stack(new_node_uid, dim=1)
365
+
366
+ merged_node_uid = unique_unsorted(
367
+ torch.cat([old_node_uid, new_node_uid], dim=0))
368
+ num_old_nodes = old_node_uid.shape[0]
369
+
370
+ # Recompute num sampled nodes for second output,
371
+ # subtracting out nodes already seen in first output
372
+ merged_node_num_sampled_nodes = None
373
+ if (self.num_sampled_nodes is not None
374
+ and other.num_sampled_nodes is not None):
375
+ merged_node_num_sampled_nodes = copy.copy(
376
+ self.num_sampled_nodes)
377
+ curr_index = 0
378
+ # NOTE: There's an assumption here that no two nodes will be
379
+ # sampled twice in the same SampleOutput object
380
+ for minibatch in other.num_sampled_nodes:
381
+ size_of_intersect = torch.cat([
382
+ old_node_uid,
383
+ new_node_uid[curr_index:curr_index + minibatch]
384
+ ]).unique(dim=0, sorted=False).shape[0] - num_old_nodes
385
+ merged_node_num_sampled_nodes.append(size_of_intersect)
386
+ curr_index += minibatch
387
+
388
+ merged_nodes = merged_node_uid[:, 0]
389
+ merged_batch = None
390
+ if self.batch is not None and other.batch is not None:
391
+ # Restore the batch indices to be relative to the nodes field
392
+ ref_merged_batch_nodes = merged_node_uid[:, 1].unsqueeze(
393
+ -1).expand(-1, 2) # num_nodes x 2
394
+ merged_batch = global_to_local_node_idx(
395
+ merged_node_uid, ref_merged_batch_nodes)
396
+
397
+ # EDGES
398
+ is_bidirectional = self.orig_row is not None \
399
+ and self.orig_col is not None \
400
+ and other.orig_row is not None \
401
+ and other.orig_col is not None
402
+ if is_bidirectional:
403
+ old_row, old_col = self.orig_row, self.orig_col
404
+ new_row, new_col = other.orig_row, other.orig_col
405
+ else:
406
+ old_row, old_col = self.row, self.col
407
+ new_row, new_col = other.row, other.col
408
+
409
+ # Transform the row and col indices to be global node ids
410
+ # instead of relative indices to nodes field
411
+ # Edge uids build off of node uids
412
+ old_row_idx, old_col_idx = local_to_global_node_idx(
413
+ old_node_uid,
414
+ old_row), local_to_global_node_idx(old_node_uid, old_col)
415
+ new_row_idx, new_col_idx = local_to_global_node_idx(
416
+ new_node_uid,
417
+ new_row), local_to_global_node_idx(new_node_uid, new_col)
418
+
419
+ old_edge_uid, new_edge_uid = [old_row_idx, old_col_idx
420
+ ], [new_row_idx, new_col_idx]
421
+
422
+ row_idx = 0
423
+ col_idx = old_row_idx.shape[1]
424
+ edge_idx = old_row_idx.shape[1] + old_col_idx.shape[1]
425
+
426
+ if self.edge is not None and other.edge is not None:
427
+ if is_bidirectional:
428
+ # bidirectional duplicates edge ids
429
+ old_edge_uid_ref = torch.stack([self.row, self.col],
430
+ dim=1) # num_edges x 2
431
+ old_orig_edge_uid_ref = torch.stack(
432
+ [self.orig_row, self.orig_col],
433
+ dim=1) # num_orig_edges x 2
434
+
435
+ old_edge_idx = global_to_local_node_idx(
436
+ old_edge_uid_ref, old_orig_edge_uid_ref)
437
+ old_edge = self.edge[old_edge_idx]
438
+
439
+ new_edge_uid_ref = torch.stack([other.row, other.col],
440
+ dim=1) # num_edges x 2
441
+ new_orig_edge_uid_ref = torch.stack(
442
+ [other.orig_row, other.orig_col],
443
+ dim=1) # num_orig_edges x 2
444
+
445
+ new_edge_idx = global_to_local_node_idx(
446
+ new_edge_uid_ref, new_orig_edge_uid_ref)
447
+ new_edge = other.edge[new_edge_idx]
448
+
449
+ else:
450
+ old_edge, new_edge = self.edge, other.edge
451
+
452
+ old_edge_uid.append(old_edge.unsqueeze(-1))
453
+ new_edge_uid.append(new_edge.unsqueeze(-1))
454
+
455
+ old_edge_uid = torch.cat(old_edge_uid, dim=1)
456
+ new_edge_uid = torch.cat(new_edge_uid, dim=1)
457
+
458
+ merged_edge_uid = unique_unsorted(
459
+ torch.cat([old_edge_uid, new_edge_uid], dim=0))
460
+ num_old_edges = old_edge_uid.shape[0]
461
+
462
+ merged_edge_num_sampled_edges = None
463
+ if (self.num_sampled_edges is not None
464
+ and other.num_sampled_edges is not None):
465
+ merged_edge_num_sampled_edges = copy.copy(
466
+ self.num_sampled_edges)
467
+ curr_index = 0
468
+ # NOTE: There's an assumption here that no two edges will be
469
+ # sampled twice in the same SampleOutput object
470
+ for minibatch in other.num_sampled_edges:
471
+ size_of_intersect = torch.cat([
472
+ old_edge_uid,
473
+ new_edge_uid[curr_index:curr_index + minibatch]
474
+ ]).unique(dim=0, sorted=False).shape[0] - num_old_edges
475
+ merged_edge_num_sampled_edges.append(size_of_intersect)
476
+ curr_index += minibatch
477
+
478
+ merged_row = merged_edge_uid[:, row_idx:col_idx]
479
+ merged_col = merged_edge_uid[:, col_idx:edge_idx]
480
+ merged_edge = merged_edge_uid[:, edge_idx:].squeeze() \
481
+ if self.edge is not None and other.edge is not None else None
482
+
483
+ # restore to row and col indices relative to nodes field
484
+ merged_row = global_to_local_node_idx(merged_node_uid, merged_row)
485
+ merged_col = global_to_local_node_idx(merged_node_uid, merged_col)
486
+
487
+ out = SamplerOutput(
488
+ node=merged_nodes,
489
+ row=merged_row,
490
+ col=merged_col,
491
+ edge=merged_edge,
492
+ batch=merged_batch,
493
+ num_sampled_nodes=merged_node_num_sampled_nodes,
494
+ num_sampled_edges=merged_edge_num_sampled_edges,
495
+ metadata=[self.metadata, other.metadata],
496
+ )
497
+ # Restores orig_row and orig_col if they existed before merging
498
+ if is_bidirectional:
499
+ out = out.to_bidirectional(keep_orig_edges=True)
500
+ return out
501
+
266
502
 
267
503
  @dataclass
268
504
  class HeteroSamplerOutput(CastMixin):
@@ -439,6 +675,25 @@ class HeteroSamplerOutput(CastMixin):
439
675
 
440
676
  return out
441
677
 
678
+ @classmethod
679
+ def collate(cls, outputs: List['HeteroSamplerOutput'],
680
+ replace: bool = True) -> 'HeteroSamplerOutput':
681
+ r"""Collate a list of
682
+ :class:`~torch_geometric.sampler.HeteroSamplerOutput`objects into a
683
+ single :class:`~torch_geometric.sampler.HeteroSamplerOutput` object.
684
+ Requires that they all have the same fields.
685
+ """
686
+ # TODO(zaristei)
687
+ raise NotImplementedError
688
+
689
+ def merge_with(self, other: 'HeteroSamplerOutput',
690
+ replace: bool = True) -> 'HeteroSamplerOutput':
691
+ """Merges two HeteroSamplerOutputs.
692
+ If replace is True, self's nodes and edges take precedence.
693
+ """
694
+ # TODO(zaristei)
695
+ raise NotImplementedError
696
+
442
697
 
443
698
  @dataclass(frozen=True)
444
699
  class NumNeighbors: