pyg-nightly 2.7.0.dev20250407__py3-none-any.whl → 2.7.0.dev20250408__py3-none-any.whl
This diff shows the changes between two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- {pyg_nightly-2.7.0.dev20250407.dist-info → pyg_nightly-2.7.0.dev20250408.dist-info}/METADATA +10 -22
- {pyg_nightly-2.7.0.dev20250407.dist-info → pyg_nightly-2.7.0.dev20250408.dist-info}/RECORD +10 -10
- torch_geometric/__init__.py +1 -1
- torch_geometric/distributed/rpc.py +1 -1
- torch_geometric/metrics/link_pred.py +10 -1
- torch_geometric/nn/models/gpse.py +12 -27
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/utils/_negative_sampling.py +2 -3
- {pyg_nightly-2.7.0.dev20250407.dist-info → pyg_nightly-2.7.0.dev20250408.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250407.dist-info → pyg_nightly-2.7.0.dev20250408.dist-info}/licenses/LICENSE +0 -0
{pyg_nightly-2.7.0.dev20250407.dist-info → pyg_nightly-2.7.0.dev20250408.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pyg-nightly
-Version: 2.7.0.dev20250407
+Version: 2.7.0.dev20250408
 Summary: Graph Neural Network Library for PyTorch
 Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
 Author-email: Matthias Fey <matthias@pyg.org>
@@ -413,18 +413,6 @@ These approaches have been implemented in PyG, and can benefit from the above GN

 PyG is available for Python 3.9 to Python 3.12.

-### Anaconda
-
-You can now install PyG via [Anaconda](https://anaconda.org/pyg/pyg) for all major OS/PyTorch/CUDA combinations 🤗
-If you have not yet installed PyTorch, install it via `conda` as described in the [official PyTorch documentation](https://pytorch.org/get-started/locally/).
-Given that you have PyTorch installed (`>=1.8.0`), simply run
-
-```
-conda install pyg -c pyg
-```
-
-### PyPi
-
 From **PyG 2.3** onwards, you can install and use PyG **without any external library** required except for PyTorch.
 For this, simply run

@@ -448,28 +436,28 @@ We recommend to start with a minimal installation, and install additional depend

 For ease of installation of these extensions, we provide `pip` wheels for all major OS/PyTorch/CUDA combinations, see [here](https://data.pyg.org/whl).

-#### PyTorch 2.
+#### PyTorch 2.6

-To install the binaries for PyTorch 2.
+To install the binaries for PyTorch 2.6.0, simply run

 ```
-pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.
+pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.6.0+${CUDA}.html
 ```

-where `${CUDA}` should be replaced by either `cpu`, `cu118`, `
+where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu124`, or `cu126` depending on your PyTorch installation.

-|             | `cpu` | `cu118` | `
+|             | `cpu` | `cu118` | `cu124` | `cu126` |
 | ----------- | ----- | ------- | ------- | ------- |
 | **Linux**   | ✅    | ✅      | ✅      | ✅      |
 | **Windows** | ✅    | ✅      | ✅      | ✅      |
 | **macOS**   | ✅    |         |         |         |

-#### PyTorch 2.
+#### PyTorch 2.5

-To install the binaries for PyTorch 2.
+To install the binaries for PyTorch 2.5.0/2.5.1, simply run

 ```
-pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.
+pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.5.0+${CUDA}.html
 ```

 where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124` depending on your PyTorch installation.
@@ -480,7 +468,7 @@ where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124`
 | **Windows** | ✅ | ✅ | ✅ | ✅ |
 | **macOS**   | ✅ |    |    |    |

-**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, PyTorch 2.2.0/2.2.1/2.2.2,
+**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, PyTorch 2.2.0/2.2.1/2.2.2, PyTorch 2.3.0/2.3.1, and PyTorch 2.4.0/2.4.1 (following the same procedure).
 **For older versions, you might need to explicitly specify the latest supported version number** or install via `pip install --no-index` in order to prevent a manual installation from source.
 You can look up the latest supported version number [here](https://data.pyg.org/whl).

{pyg_nightly-2.7.0.dev20250407.dist-info → pyg_nightly-2.7.0.dev20250408.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
-torch_geometric/__init__.py,sha256=
+torch_geometric/__init__.py,sha256=rucRaUNeMS0eQLE0b6LFcLimKD5JwNS4_nYsmtC3BMU,1978
 torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
 torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
 torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -188,7 +188,7 @@ torch_geometric/distributed/event_loop.py,sha256=wr3iwMYEWOGkBlvC5huD2k5YxisaGE9
 torch_geometric/distributed/local_feature_store.py,sha256=CLW37RN0ouDufEs2tY9d2nLLvpxubiT6zgW3LIHAU8k,19058
 torch_geometric/distributed/local_graph_store.py,sha256=wNHXSS824Kk2HynbtWFXx-W4pl97UUBv6qFHAVqO3W4,8445
 torch_geometric/distributed/partition.py,sha256=X0BleuY0ROlUtVXKvvz814pJwglZQ2_6OiMi1K0Hhvo,14731
-torch_geometric/distributed/rpc.py,sha256=
+torch_geometric/distributed/rpc.py,sha256=rJqiVR6Vbb2mpyVSC0Y5tPApqP-b1ck1Uq3IQpCsNSw,5737
 torch_geometric/distributed/utils.py,sha256=FGrr3qw7hx7EQaIjjqasurloCFJ9q_0jt8jdSIUjBeM,6567
 torch_geometric/explain/__init__.py,sha256=pRxVB33zsxhED1StRWdHboQWh3e06__g9N298Hzi42Y,359
 torch_geometric/explain/config.py,sha256=_0j67NAwPwjrWHPncNywCT-oKyMiryJNxufxVN1BFlM,7834
@@ -290,7 +290,7 @@ torch_geometric/loader/temporal_dataloader.py,sha256=AQ2QFeiXKbPp6I8sUeE8H7br-1_
 torch_geometric/loader/utils.py,sha256=f27mczQ7fEP2HpTsJGJxKS0slPu0j8zTba3jP8ViNck,14901
 torch_geometric/loader/zip_loader.py,sha256=3lt10fD15Rxm1WhWzypswGzCEwUz4h8OLCD1nE15yNg,3843
 torch_geometric/metrics/__init__.py,sha256=3krvDobW6vV5yHTjq2S2pmOXxNfysNG26muq7z48e94,699
-torch_geometric/metrics/link_pred.py,sha256=
+torch_geometric/metrics/link_pred.py,sha256=dtaI39JB-WqE1B-raiElns6xySRwmkbb9izbcyt6xHI,30886
 torch_geometric/nn/__init__.py,sha256=kQHHHUxFDht2ztD-XFQuv98TvC8MdodaFsIjAvltJBw,874
 torch_geometric/nn/data_parallel.py,sha256=lDAxRi83UNuzAQSj3eu9K2sQheOIU6wqR5elS6oDs90,4764
 torch_geometric/nn/encoding.py,sha256=QNjwWczYExZ1wRGBmpuqYbn6tB7NC4BU-DEgzjhcZqw,3115
@@ -441,7 +441,7 @@ torch_geometric/nn/models/g_retriever.py,sha256=CdSOasnPiMvq5AjduNTpz-LIZiNp3X0x
 torch_geometric/nn/models/git_mol.py,sha256=Wc6Hx6RDDR7sDWRWHfA5eK9e9gFsrTZ9OLmpMfoj3pE,12676
 torch_geometric/nn/models/glem.py,sha256=sT0XM4klVlci9wduvUoXupATUw9p25uXtaJBrmv3yvs,16431
 torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
-torch_geometric/nn/models/gpse.py,sha256=
+torch_geometric/nn/models/gpse.py,sha256=Fwldw9N3axV--BcSnI4im1sy1r87a5mAXZAXHu_6k2Y,41932
 torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
 torch_geometric/nn/models/graph_unet.py,sha256=N8TSmJo8AlbZjjcame0xW_jZvMOirL5ahw6qv5Yjpbs,5586
 torch_geometric/nn/models/jumping_knowledge.py,sha256=9JR2EoViXKjcDSLb8tjJm-UHfv1mQCJvZAAEsYa0Ocw,5496
@@ -571,7 +571,7 @@ torch_geometric/transforms/remove_duplicated_edges.py,sha256=EMX6E1R_gXFoXwBqMqt
 torch_geometric/transforms/remove_isolated_nodes.py,sha256=Q89b73es1tPsAmTdS7tWTIM7JcPUpL37v3EZTAd25Fc,2449
 torch_geometric/transforms/remove_self_loops.py,sha256=JfoooSnTO2KPXXuC3KWGhLS0tqr4yiXOl3A0sVv2riM,1221
 torch_geometric/transforms/remove_training_classes.py,sha256=GMCZwI_LYo2ZF29DABZXeuM0Sn2i3twx_V3KBUGu2As,932
-torch_geometric/transforms/rooted_subgraph.py,sha256=
+torch_geometric/transforms/rooted_subgraph.py,sha256=2P_2NSRcDNzfeR0zYScCaqZDkAvC0t3Ivr_lINO9abc,6504
 torch_geometric/transforms/sample_points.py,sha256=UfD44528J7SKH0I2_4ELM1A4hKKHCIDeMV6UzBbQAVU,2280
 torch_geometric/transforms/sign.py,sha256=bZUvUm9fMGXcYkI1GwPOW5ZC1QFu84vPBKpFnYxz2nA,2329
 torch_geometric/transforms/spherical.py,sha256=nU7h4IFw69JqUwRqaweVEBegHZWPOHDoTYYVRMzIZ7U,2320
@@ -592,7 +592,7 @@ torch_geometric/utils/_grid.py,sha256=1coutST2TMV9TSQcmpXze0GIK9odzZ9wBtbKs6u26D
 torch_geometric/utils/_homophily.py,sha256=1nXxGUATFPB3icEGpvEWUiuYbjU9gDGtlWpuLbtWhJk,5090
 torch_geometric/utils/_index_sort.py,sha256=FTJacmOsqgsyof7MJFHlVVdXhHOjR0j7siTb0UZ-YT0,1283
 torch_geometric/utils/_lexsort.py,sha256=chMEJJRXqfE6-K4vrVszdr3c338EhMZyi0Q9IEJD3p0,1403
-torch_geometric/utils/_negative_sampling.py,sha256=
+torch_geometric/utils/_negative_sampling.py,sha256=jxsmpryeoTT8qQrvIH11MgyhgoWzvqPGRAcVyU85VCU,15494
 torch_geometric/utils/_normalize_edge_index.py,sha256=H6DY-Dzi1Psr3igG_nb0U3ZPNZz-BBDntO2iuA8FtzA,1682
 torch_geometric/utils/_normalized_cut.py,sha256=uwVJkl-Q0tpY-w0nvcHajcQYcqFh1oDOf55XELdjJBU,1167
 torch_geometric/utils/_one_hot.py,sha256=vXC7l7zudYRZIwWv6mT-Biuk2zKELyqteJXLynPocPM,1404
@@ -636,7 +636,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
 torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
 torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
 torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
-pyg_nightly-2.7.0.
-pyg_nightly-2.7.0.
-pyg_nightly-2.7.0.
-pyg_nightly-2.7.0.
+pyg_nightly-2.7.0.dev20250408.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
+pyg_nightly-2.7.0.dev20250408.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+pyg_nightly-2.7.0.dev20250408.dist-info/METADATA,sha256=OX81Kse1GucsZCReO0E00CJl4rrtT9HhHEkNEbJU6KY,62652
+pyg_nightly-2.7.0.dev20250408.dist-info/RECORD,,
torch_geometric/__init__.py
CHANGED
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')

-__version__ = '2.7.0.dev20250407'
+__version__ = '2.7.0.dev20250408'

 __all__ = [
     'Index',
torch_geometric/distributed/rpc.py
CHANGED
@@ -144,7 +144,7 @@ _rpc_call_pool: Dict[int, RPCCallBase] = {}
 @rpc_require_initialized
 def rpc_register(call: RPCCallBase) -> int:
     r"""Registers a call for RPC requests."""
-    global _rpc_call_id
+    global _rpc_call_id

     with _rpc_call_lock:
         call_id = _rpc_call_id
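The hunk above only touches the `global` statement in `rpc_register`, but the surrounding pattern is the interesting part: a module-level counter is bumped under a lock so every registered callable receives a unique id. A minimal, self-contained sketch of that pattern (the `_call_id`/`_call_lock`/`_call_pool` names are illustrative placeholders, not PyG's actual module state):

```python
import threading
from typing import Callable, Dict

_call_id: int = 0                      # next id to hand out
_call_lock = threading.RLock()         # guards the counter and the pool
_call_pool: Dict[int, Callable] = {}   # id -> registered callable


def register(call: Callable) -> int:
    """Register a callable and return its unique id."""
    global _call_id  # only the integer is rebound; the dict is mutated in place
    with _call_lock:
        call_id = _call_id
        _call_id += 1
        _call_pool[call_id] = call
    return call_id


print(register(print), register(len))  # 0 1
```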
torch_geometric/metrics/link_pred.py
CHANGED
@@ -670,6 +670,9 @@ class LinkPredDiversity(_LinkPredMetric):
     def __init__(self, k: int, category: Tensor) -> None:
         super().__init__(k)

+        self.accum: Tensor
+        self.total: Tensor
+
         if WITH_TORCHMETRICS:
             self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
             self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
@@ -736,11 +739,14 @@ class LinkPredPersonalization(_LinkPredMetric):
         self.max_src_nodes = max_src_nodes
         self.batch_size = batch_size

+        self.preds: List[Tensor]
+        self.total: Tensor
+
         if WITH_TORCHMETRICS:
             self.add_state('preds', default=[], dist_reduce_fx='cat')
             self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
         else:
-            self.preds
+            self.preds = []
             self.register_buffer('total', torch.tensor(0), persistent=False)

     def update(
@@ -826,6 +832,9 @@ class LinkPredAveragePopularity(_LinkPredMetric):
     def __init__(self, k: int, popularity: Tensor) -> None:
         super().__init__(k)

+        self.accum: Tensor
+        self.total: Tensor
+
         if WITH_TORCHMETRICS:
             self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
             self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
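All three hunks add the same thing: explicit `Tensor` (or `List[Tensor]`) annotations for attributes that are only created indirectly, either by torchmetrics' `add_state` or by `register_buffer`, so the attributes are declared on the instance regardless of which branch runs. A stripped-down sketch of the pattern (the `WITH_TORCHMETRICS` flag and `ToyMetric` are stand-ins, not the PyG classes):

```python
import torch
from torch import Tensor

WITH_TORCHMETRICS = False  # placeholder for the optional torchmetrics dependency


class ToyMetric(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

        # Declared up front so type checkers know these attributes exist,
        # no matter which branch below actually creates them:
        self.accum: Tensor
        self.total: Tensor

        if WITH_TORCHMETRICS:
            # torchmetrics.Metric.add_state would create the state here.
            ...
        else:
            self.register_buffer('accum', torch.tensor(0.), persistent=False)
            self.register_buffer('total', torch.tensor(0), persistent=False)

    def update(self, values: Tensor) -> None:
        self.accum += values.sum()
        self.total += values.numel()


metric = ToyMetric()
metric.update(torch.ones(4))
print(metric.accum / metric.total)  # tensor(1.)
```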
torch_geometric/nn/models/gpse.py
CHANGED
@@ -379,7 +379,7 @@ class GPSE(torch.nn.Module):

     .. code-block:: python

-        from torch_geometric.nn import GPSE, GPSENodeEncoder
+        from torch_geometric.nn import GPSE, GPSENodeEncoder
         from torch_geometric.transforms import AddGPSE
         from torch_geometric.nn.models.gpse import precompute_GPSE

@@ -417,13 +417,11 @@ class GPSE(torch.nn.Module):

         encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64,
                                   expand_x=False)
-        gnn = GNN(
+        gnn = GNN(...)

         for batch in loader:
-
-
-            # Do something with the batch, which now includes 128-dimensional
-            # node representations
+            x = encoder(batch.x, batch.pestat_GPSE)
+            out = gnn(x, batch.edge_index)


     Args:
@@ -571,8 +569,7 @@ class GPSE(torch.nn.Module):
         self.reset_parameters()

     def reset_parameters(self):
-
-        self.apply(init_weights)
+        pass

     @classmethod
     def from_pretrained(cls, name: str, root: str = 'GPSE_pretrained'):
@@ -608,6 +605,7 @@ class GPSE(torch.nn.Module):
         return model

     def forward(self, batch):
+        batch = batch.clone()
         for module in self.children():
             batch = module(batch)
         return batch
@@ -635,13 +633,11 @@ class GPSENodeEncoder(torch.nn.Module):

         encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64,
                                   expand_x=False)
-        gnn = GNN(
+        gnn = GNN(...)

         for batch in loader:
-
-            batch = gnn(batch)
-            # Do something with the batch, which now includes 128-dimensional
-            # node representations
+            x = encoder(batch.x, batch.pestat_GPSE)
+            batch = gnn(x, batch.edge_index)

     Args:
         dim_emb (int): Size of final node embedding.
@@ -705,28 +701,17 @@ class GPSENodeEncoder(torch.nn.Module):
             raise ValueError(f"{self.__class__.__name__}: Does not support "
                              f"'{model_type}' encoder model.")

-    def forward(self,
-        if not hasattr(batch, 'pestat_GPSE'):
-            raise ValueError('Precomputed "pestat_GPSE" variable is required '
-                             'for GNNNodeEncoder; either run '
-                             '`precompute_GPSE(gpse_model, dataset)` on your '
-                             'dataset or add `AddGPSE(gpse_model)` as a (pre) '
-                             'transform.')
-
-        pos_enc = batch.pestat_GPSE
-
+    def forward(self, x, pos_enc):
         pos_enc = self.dropout_be(pos_enc)
         pos_enc = self.raw_norm(pos_enc) if self.raw_norm else pos_enc
         pos_enc = self.pe_encoder(pos_enc)  # (Num nodes) x dim_pe
         pos_enc = self.dropout_ae(pos_enc)

         # Expand node features if needed
-        h = self.linear_x(
+        h = self.linear_x(x) if self.expand_x else x

         # Concatenate final PEs to input embedding
-
-
-        return batch
+        return torch.cat((h, pos_enc), 1)


     @torch.no_grad()
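Beyond the `batch.clone()` guard in `GPSE.forward`, the substantive change here is the `GPSENodeEncoder.forward` signature: instead of consuming a whole batch and reading `batch.pestat_GPSE` itself, it now takes the node features and the precomputed encodings explicitly and returns their concatenation. A toy re-implementation of that contract, for illustration only (not the PyG class; the layer sizes are arbitrary):

```python
import torch
from torch import Tensor


class ToyNodeEncoder(torch.nn.Module):
    """Mimics the new forward(x, pos_enc) contract of GPSENodeEncoder."""
    def __init__(self, dim_in: int, dim_pe_in: int, dim_pe_out: int,
                 dim_emb: int, expand_x: bool = True) -> None:
        super().__init__()
        self.expand_x = expand_x
        # Project raw features so that features + PE add up to dim_emb:
        self.linear_x = torch.nn.Linear(dim_in, dim_emb - dim_pe_out)
        self.pe_encoder = torch.nn.Linear(dim_pe_in, dim_pe_out)

    def forward(self, x: Tensor, pos_enc: Tensor) -> Tensor:
        pos_enc = self.pe_encoder(pos_enc)            # [num_nodes, dim_pe_out]
        h = self.linear_x(x) if self.expand_x else x  # [num_nodes, *]
        return torch.cat((h, pos_enc), dim=1)         # concatenated embedding


encoder = ToyNodeEncoder(dim_in=16, dim_pe_in=32, dim_pe_out=64, dim_emb=128)
x, pos_enc = torch.randn(10, 16), torch.randn(10, 32)
print(encoder(x, pos_enc).shape)  # torch.Size([10, 128])
```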
torch_geometric/transforms/rooted_subgraph.py
CHANGED
@@ -94,7 +94,7 @@ class RootedSubgraph(BaseTransform, ABC):
         arange = torch.arange(n_id.size(0), device=data.edge_index.device)
         node_map = data.edge_index.new_ones(num_nodes, num_nodes)
         node_map[n_sub_batch, n_id] = arange
-        sub_edge_index += (arange *
+        sub_edge_index += (arange * num_nodes)[e_sub_batch]
         sub_edge_index = node_map.view(-1)[sub_edge_index]

         return sub_edge_index, n_id, e_id, n_sub_batch, e_sub_batch
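The fixed line restores the per-subgraph offset: each edge's endpoints are shifted by `subgraph_id * num_nodes`, so a single gather from the flattened lookup table relabels global node ids to positions in the concatenated `n_id` tensor. A toy illustration of that indexing trick (two subgraphs of a 4-node graph; the table shape is simplified relative to the actual transform):

```python
import torch

num_nodes = 4                                 # nodes in the full graph
n_sub_batch = torch.tensor([0, 0, 1, 1])      # subgraph id of each sampled node
n_id = torch.tensor([0, 2, 1, 3])             # global id of each sampled node
sub_edge_index = torch.tensor([[0, 1],        # edge endpoints as *global* ids
                               [2, 3]])
e_sub_batch = torch.tensor([0, 1])            # subgraph id of each edge

# Lookup table: (subgraph, global id) -> position of that node in `n_id`.
arange = torch.arange(n_id.size(0))
node_map = torch.ones(2, num_nodes, dtype=torch.long)
node_map[n_sub_batch, n_id] = arange

# Shift every edge into its subgraph's row of the flattened table
# (equivalent to `(arange * num_nodes)[e_sub_batch]`, since arange[i] == i):
sub_edge_index = sub_edge_index + e_sub_batch * num_nodes
# One gather then relabels both endpoints of every edge:
sub_edge_index = node_map.view(-1)[sub_edge_index]
print(sub_edge_index)  # tensor([[0, 2], [1, 3]]) -> indices into n_id
```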
torch_geometric/utils/_negative_sampling.py
CHANGED
@@ -109,10 +109,9 @@ def negative_sampling(
     idx = idx.to('cpu')
     for _ in range(3):  # Number of tries to sample negative indices.
         rnd = sample(population, sample_size, device='cpu')
-        mask = np.isin(rnd.numpy(), idx.numpy())
+        mask = torch.from_numpy(np.isin(rnd.numpy(), idx.numpy())).bool()
         if neg_idx is not None:
-            mask |= np.isin(rnd, neg_idx.
-        mask = torch.from_numpy(mask).to(torch.bool)
+            mask |= torch.from_numpy(np.isin(rnd, neg_idx.cpu())).bool()
         rnd = rnd[~mask].to(edge_index.device)
         neg_idx = rnd if neg_idx is None else torch.cat([neg_idx, rnd])
         if neg_idx.numel() >= num_neg_samples:
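The negative-sampling fix converts each `np.isin` result into a `torch.bool` tensor right away, so the in-place `|=` combines two torch tensors instead of mixing a NumPy array into the mask. A standalone sketch of the same membership-mask pattern (the index values below are made up):

```python
import numpy as np
import torch

rnd = torch.tensor([3, 7, 1, 9, 4])   # candidate negative indices
idx = torch.tensor([1, 4])            # indices that already exist as edges
neg_idx = torch.tensor([9])           # negatives drawn in a previous round

# Build the rejection mask as a torch.bool tensor from the start:
mask = torch.from_numpy(np.isin(rnd.numpy(), idx.numpy())).bool()
if neg_idx is not None:
    mask |= torch.from_numpy(np.isin(rnd.numpy(), neg_idx.numpy())).bool()

print(rnd[~mask])  # tensor([3, 7]); only previously unseen indices survive
```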
{pyg_nightly-2.7.0.dev20250407.dist-info → pyg_nightly-2.7.0.dev20250408.dist-info}/WHEEL
RENAMED
File without changes
{pyg_nightly-2.7.0.dev20250407.dist-info → pyg_nightly-2.7.0.dev20250408.dist-info}/licenses/LICENSE
RENAMED
File without changes