pyg-nightly 2.7.0.dev20250503__py3-none-any.whl → 2.7.0.dev20250505__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250503.dist-info → pyg_nightly-2.7.0.dev20250505.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250503.dist-info → pyg_nightly-2.7.0.dev20250505.dist-info}/RECORD +16 -16
- torch_geometric/__init__.py +8 -1
- torch_geometric/data/collate.py +1 -3
- torch_geometric/edge_index.py +2 -7
- torch_geometric/nn/dense/linear.py +0 -20
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/typing.py +0 -2
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_scatter.py +129 -194
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +1 -2
- torch_geometric/utils/sparse.py +4 -12
- {pyg_nightly-2.7.0.dev20250503.dist-info → pyg_nightly-2.7.0.dev20250505.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250503.dist-info → pyg_nightly-2.7.0.dev20250505.dist-info}/licenses/LICENSE +0 -0
{pyg_nightly-2.7.0.dev20250503.dist-info → pyg_nightly-2.7.0.dev20250505.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: pyg-nightly
|
3
|
-
Version: 2.7.0.
|
3
|
+
Version: 2.7.0.dev20250505
|
4
4
|
Summary: Graph Neural Network Library for PyTorch
|
5
5
|
Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
|
6
6
|
Author-email: Matthias Fey <matthias@pyg.org>
|
@@ -1,4 +1,4 @@
|
|
1
|
-
torch_geometric/__init__.py,sha256=
|
1
|
+
torch_geometric/__init__.py,sha256=BiduxVS51etguzPOVckBI8kTSmBUcEA48YUl_PxIJlw,2255
|
2
2
|
torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
|
3
3
|
torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
|
4
4
|
torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
|
@@ -7,7 +7,7 @@ torch_geometric/config_store.py,sha256=zdMzlgBpUmBkPovpYQh5fMNwTZLDq2OneqX47QEx7
|
|
7
7
|
torch_geometric/debug.py,sha256=cLyH9OaL2v7POyW-80b19w-ctA7a_5EZsS4aUF1wc2U,1295
|
8
8
|
torch_geometric/deprecation.py,sha256=dWRymDIUkUVI2MeEmBG5WF4R6jObZeseSBV9G6FNfjc,858
|
9
9
|
torch_geometric/device.py,sha256=tU5-_lBNVbVHl_kUmWPwiG5mQ1pyapwMF4JkmtNN3MM,1224
|
10
|
-
torch_geometric/edge_index.py,sha256=
|
10
|
+
torch_geometric/edge_index.py,sha256=jSWrZ77qKKydVDxiXrsBlaoL6Qdems6-HiA_B_qDo2o,70078
|
11
11
|
torch_geometric/experimental.py,sha256=JbtNNEXjFGI8hZ9raM6-qrZURP6Z5nlDK8QicZUIbz0,4756
|
12
12
|
torch_geometric/hash_tensor.py,sha256=WB-aBCJWNWqnlnzQ8Ob4LHeCXm0u1_NPPhmNAEwBpq4,24906
|
13
13
|
torch_geometric/home.py,sha256=EV54B4Dmiv61GDbkCwtCfWGWJ4eFGwZ8s3KOgGjwYgY,790
|
@@ -19,7 +19,7 @@ torch_geometric/logging.py,sha256=HmHHLiCcM64k-6UYNOSfXPIeSGNAyiGGcn8cD8tlyuQ,85
|
|
19
19
|
torch_geometric/resolver.py,sha256=fn-_6mCpI2xv7eDZnIFcYrHOn0IrwbkWFLDb9laQrWI,1270
|
20
20
|
torch_geometric/seed.py,sha256=MJLbVwpb9i8mK3oi32sS__Cq-dRq_afTeoOL_HoA9ko,372
|
21
21
|
torch_geometric/template.py,sha256=rqjDWgcSAgTCiV4bkOjWRPaO4PpUdC_RXigzxxBqAu8,1060
|
22
|
-
torch_geometric/typing.py,sha256=
|
22
|
+
torch_geometric/typing.py,sha256=Ryx6oGoOsEh8rJ3O0O6j8O18ZPHkrIv-7dr-suQZa6Q,15486
|
23
23
|
torch_geometric/warnings.py,sha256=t114CbkrmiqkXaavx5g7OO52dLdktf-U__B5QqYIQvI,413
|
24
24
|
torch_geometric/contrib/__init__.py,sha256=0pWkmXfZtbdr-AKwlii5LTFggTEH-MCrSKpZxrtPlVs,352
|
25
25
|
torch_geometric/contrib/datasets/__init__.py,sha256=lrGnWsEiJf5zsBRmshGZZFN_uYR2ezDjbj9n9nCpvtk,23
|
@@ -32,7 +32,7 @@ torch_geometric/contrib/nn/models/rbcd_attack.py,sha256=qcyxBxAbx8LKzpp3RoJQ0cxl
|
|
32
32
|
torch_geometric/contrib/transforms/__init__.py,sha256=lrGnWsEiJf5zsBRmshGZZFN_uYR2ezDjbj9n9nCpvtk,23
|
33
33
|
torch_geometric/data/__init__.py,sha256=D6Iz5A9vEb_2rpf96Zn7uM-lchZ3WpW8X7WdAD1yxKw,4565
|
34
34
|
torch_geometric/data/batch.py,sha256=8X8CN4_1rjrh48R3R2--mZUgfsO7Po9JP-H6SbrBiBA,8740
|
35
|
-
torch_geometric/data/collate.py,sha256=
|
35
|
+
torch_geometric/data/collate.py,sha256=tOUvttXoEo-bOvJx_qMivJq2JqOsB9iDdjovtiyys4o,12644
|
36
36
|
torch_geometric/data/data.py,sha256=mp_jsjsaVwUcY-FghlqNZTHUQEKBdi7xWR_oA2ewrD4,43821
|
37
37
|
torch_geometric/data/database.py,sha256=VTct1xyzXsK0GZahBV9-noviCzjRteAsKMG7VgJ52n0,22998
|
38
38
|
torch_geometric/data/datapipes.py,sha256=9_Cq3j_7LIF4plQFzbLaqyy0LcpKdAic6yiKgMqSX9A,3083
|
@@ -415,7 +415,7 @@ torch_geometric/nn/dense/dense_graph_conv.py,sha256=_7y-EmyStVouGPyA2H5baufNZHwj
|
|
415
415
|
torch_geometric/nn/dense/dense_sage_conv.py,sha256=erfy0RAWOAkbRi8QXBVgkv37QeSo8XdcXYGYLZBgY7A,2672
|
416
416
|
torch_geometric/nn/dense/diff_pool.py,sha256=bHIKbfV8Fv36H611V8bDpT6ACTKx8d1-hKzDXm5dQ9g,3051
|
417
417
|
torch_geometric/nn/dense/dmon_pool.py,sha256=l4usDrjX4LAVcAU2jTXte4aUk8UeyDStlRHzKwwpi8s,6115
|
418
|
-
torch_geometric/nn/dense/linear.py,sha256=
|
418
|
+
torch_geometric/nn/dense/linear.py,sha256=rAXLXq8kAuyyR1fxNOj8Rj_w7Zk5TMergcPUsXQAWI4,17150
|
419
419
|
torch_geometric/nn/dense/mincut_pool.py,sha256=CirJKIEXICGul3ziTno-o2EqDkQLkV7m2KxdYPyI4ZI,4111
|
420
420
|
torch_geometric/nn/functional/__init__.py,sha256=ggqde0hPT7wKzWAbQaEe9yX-Jcg_tWO-wivMmAJ9rz0,129
|
421
421
|
torch_geometric/nn/functional/bro.py,sha256=_MurXJXVY1cFaCjDEAyvNoXv-Ka_Odlz-jxIS4OuDzY,1549
|
@@ -498,10 +498,10 @@ torch_geometric/nn/pool/sag_pool.py,sha256=YgNJUDd2WrE2PW9_ibQC7YaXlnSOCJ_4vt2LY
|
|
498
498
|
torch_geometric/nn/pool/topk_pool.py,sha256=0n2Bg2Pt6nVozlAJjZpIMcTCMG7o_sGkzDNzNVN8D3A,5159
|
499
499
|
torch_geometric/nn/pool/voxel_grid.py,sha256=OLe0kAsYYiJLdlgNwJYTIDX53lg1t2X_TCTawPtcU2A,2793
|
500
500
|
torch_geometric/nn/pool/connect/__init__.py,sha256=rIaO9siCtXt5wBTQnSWnDyadnGZgF1hgfQo21Foij2M,287
|
501
|
-
torch_geometric/nn/pool/connect/base.py,sha256=
|
501
|
+
torch_geometric/nn/pool/connect/base.py,sha256=2rLOCC5NGQ7CIrrY-_zWtXKVcW6MPrVJO7tFxJHw4m0,4009
|
502
502
|
torch_geometric/nn/pool/connect/filter_edges.py,sha256=LDzTjOWRjone2Gw7buBwwp2rOSRVaDmoyPJBik18BTo,2190
|
503
503
|
torch_geometric/nn/pool/select/__init__.py,sha256=V0nnZQhbWPt_yDylHD5nwCSBMYzyWfgETePvNE-a7AM,254
|
504
|
-
torch_geometric/nn/pool/select/base.py,sha256=
|
504
|
+
torch_geometric/nn/pool/select/base.py,sha256=1PhuMBqmAGscr4ERluRHapZ4-GcMMLebluuRJB0esms,3225
|
505
505
|
torch_geometric/nn/pool/select/topk.py,sha256=R1LTjOvanJqlrcDe0qinqz286qOJpmjC1tPeiQdPGcU,5305
|
506
506
|
torch_geometric/nn/unpool/__init__.py,sha256=J6I3abNR1MRxisXzbX3sBRH-hlMpmUe7FVc3UziZ67s,129
|
507
507
|
torch_geometric/nn/unpool/knn_interpolate.py,sha256=8GlKoB-wzZz6ETJP7SsKHbzwenr4JiPg6sK3uh9I6R8,2586
|
@@ -592,17 +592,17 @@ torch_geometric/utils/_degree.py,sha256=FcsGx5cQdrBmoCQ4qQ2csjsTiDICP1as4x1HD9y5
|
|
592
592
|
torch_geometric/utils/_grid.py,sha256=1coutST2TMV9TSQcmpXze0GIK9odzZ9wBtbKs6u26D8,2562
|
593
593
|
torch_geometric/utils/_homophily.py,sha256=1nXxGUATFPB3icEGpvEWUiuYbjU9gDGtlWpuLbtWhJk,5090
|
594
594
|
torch_geometric/utils/_index_sort.py,sha256=FTJacmOsqgsyof7MJFHlVVdXhHOjR0j7siTb0UZ-YT0,1283
|
595
|
-
torch_geometric/utils/_lexsort.py,sha256=
|
595
|
+
torch_geometric/utils/_lexsort.py,sha256=GvqVvDrEjurQP_zDZsWkg6zCkL4ORVPNRdTsoa2pllc,1097
|
596
596
|
torch_geometric/utils/_negative_sampling.py,sha256=jxsmpryeoTT8qQrvIH11MgyhgoWzvqPGRAcVyU85VCU,15494
|
597
597
|
torch_geometric/utils/_normalize_edge_index.py,sha256=H6DY-Dzi1Psr3igG_nb0U3ZPNZz-BBDntO2iuA8FtzA,1682
|
598
598
|
torch_geometric/utils/_normalized_cut.py,sha256=uwVJkl-Q0tpY-w0nvcHajcQYcqFh1oDOf55XELdjJBU,1167
|
599
599
|
torch_geometric/utils/_one_hot.py,sha256=vXC7l7zudYRZIwWv6mT-Biuk2zKELyqteJXLynPocPM,1404
|
600
|
-
torch_geometric/utils/_scatter.py,sha256=
|
600
|
+
torch_geometric/utils/_scatter.py,sha256=qwM8GpMiXT8uuckTg9nYDtwg2WqAWljkhjLE8omOHhk,11642
|
601
601
|
torch_geometric/utils/_segment.py,sha256=CqS7_NMQihX89gEwFVHbyMEZgaEnSlJGpyuWqy3i8HI,1976
|
602
602
|
torch_geometric/utils/_select.py,sha256=BZ5P6-1riR4xhCIJZnsNg5HmeAGelRzH42TpADj9xpQ,2439
|
603
603
|
torch_geometric/utils/_softmax.py,sha256=6dTVbWX04laemRP-ZFPMS6ymRZtRa8zYF22QCXl_m4w,3242
|
604
|
-
torch_geometric/utils/_sort_edge_index.py,sha256=
|
605
|
-
torch_geometric/utils/_spmm.py,sha256=
|
604
|
+
torch_geometric/utils/_sort_edge_index.py,sha256=Z5F9xcRp3hKzTiTlc2gqYufs8QLDSI4cuJNRI0zF0G4,4373
|
605
|
+
torch_geometric/utils/_spmm.py,sha256=zesnYzhGZoEih99iW_fMBysiWFPLLiuV8YTIQg0fOL4,5740
|
606
606
|
torch_geometric/utils/_subgraph.py,sha256=GcOGNUcVe97tifuQyi5qBZ88A_Wo3-o17l9xCSIsau4,18456
|
607
607
|
torch_geometric/utils/_to_dense_adj.py,sha256=hl1sboUBvED5Er66bqLms4VdmxKA-7Y3ozJIR-YIAUc,3606
|
608
608
|
torch_geometric/utils/_to_dense_batch.py,sha256=-K5NjjfvjKYKJQ3kXgNIDR7lwMJ_GGISI45b50IGMvY,4582
|
@@ -632,12 +632,12 @@ torch_geometric/utils/ppr.py,sha256=ebiHbQqRJsQbGUI5xu-IkzQSQsgIaC71vgO0KcXIKAk,
|
|
632
632
|
torch_geometric/utils/random.py,sha256=Rv5HlhG5310rytbT9EZ7xWLGKQfozfz1azvYi5nx2-U,5148
|
633
633
|
torch_geometric/utils/repeat.py,sha256=RxCoRoEisaP6NouXPPW5tY1Rn-tIfrmpJPm0qGP6W8M,815
|
634
634
|
torch_geometric/utils/smiles.py,sha256=lGQ2BwJ49uBrQfIxxPz8ceTO9Jo-XCjlLxs1ql3xrsA,7130
|
635
|
-
torch_geometric/utils/sparse.py,sha256=
|
635
|
+
torch_geometric/utils/sparse.py,sha256=MJyWkn-r9sdMyR_m-aBUIQUkvxsYLUNP9jYlfntSUpI,25118
|
636
636
|
torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5nUAUjw,6222
|
637
637
|
torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
|
638
638
|
torch_geometric/visualization/graph.py,sha256=PoI9tjbEXZVkMUg4CvTLbzqtEfzUwMUcsw57DNBEU0s,14311
|
639
639
|
torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
|
640
|
-
pyg_nightly-2.7.0.
|
641
|
-
pyg_nightly-2.7.0.
|
642
|
-
pyg_nightly-2.7.0.
|
643
|
-
pyg_nightly-2.7.0.
|
640
|
+
pyg_nightly-2.7.0.dev20250505.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
|
641
|
+
pyg_nightly-2.7.0.dev20250505.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
|
642
|
+
pyg_nightly-2.7.0.dev20250505.dist-info/METADATA,sha256=JqptxbND3anHymFM2E5zV7xbDWSmL7mf9V5YVQkoJmk,62979
|
643
|
+
pyg_nightly-2.7.0.dev20250505.dist-info/RECORD,,
|
torch_geometric/__init__.py
CHANGED
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
|
|
31
31
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
32
32
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
33
33
|
|
34
|
-
__version__ = '2.7.0.
|
34
|
+
__version__ = '2.7.0.dev20250505'
|
35
35
|
|
36
36
|
__all__ = [
|
37
37
|
'Index',
|
@@ -57,6 +57,13 @@ __all__ = [
|
|
57
57
|
'__version__',
|
58
58
|
]
|
59
59
|
|
60
|
+
if not torch_geometric.typing.WITH_PT113:
|
61
|
+
import warnings as std_warnings
|
62
|
+
|
63
|
+
std_warnings.warn("PyG 2.7 removed support for PyTorch < 1.13. Consider "
|
64
|
+
"Consider upgrading to PyTorch >= 1.13 or downgrading "
|
65
|
+
"to PyG <= 2.6. ")
|
66
|
+
|
60
67
|
# Serialization ###############################################################
|
61
68
|
|
62
69
|
if torch_geometric.typing.WITH_PT24:
|
torch_geometric/data/collate.py
CHANGED
@@ -191,10 +191,8 @@ def _collate(
|
|
191
191
|
if torch_geometric.typing.WITH_PT20:
|
192
192
|
storage = elem.untyped_storage()._new_shared(
|
193
193
|
numel * elem.element_size(), device=elem.device)
|
194
|
-
elif torch_geometric.typing.WITH_PT112:
|
195
|
-
storage = elem.storage()._new_shared(numel, device=elem.device)
|
196
194
|
else:
|
197
|
-
storage = elem.storage()._new_shared(numel)
|
195
|
+
storage = elem.storage()._new_shared(numel, device=elem.device)
|
198
196
|
shape = list(elem.size())
|
199
197
|
if cat_dim is None or elem.dim() == 0:
|
200
198
|
shape = [len(values)] + shape
|
torch_geometric/edge_index.py
CHANGED
@@ -298,8 +298,7 @@ class EdgeIndex(Tensor):
|
|
298
298
|
indptr = None
|
299
299
|
data = torch.stack([row, col], dim=0)
|
300
300
|
|
301
|
-
if
|
302
|
-
and data.layout == torch.sparse_csc):
|
301
|
+
if data.layout == torch.sparse_csc:
|
303
302
|
row = data.row_indices()
|
304
303
|
indptr = data.ccol_indices()
|
305
304
|
|
@@ -882,10 +881,6 @@ class EdgeIndex(Tensor):
|
|
882
881
|
If not specified, non-zero elements will be assigned a value of
|
883
882
|
:obj:`1.0`. (default: :obj:`None`)
|
884
883
|
"""
|
885
|
-
if not torch_geometric.typing.WITH_PT112:
|
886
|
-
raise NotImplementedError(
|
887
|
-
"'to_sparse_csc' not supported for PyTorch < 1.12")
|
888
|
-
|
889
884
|
(colptr, row), perm = self.get_csc()
|
890
885
|
if value is not None and perm is not None:
|
891
886
|
value = value[perm]
|
@@ -922,7 +917,7 @@ class EdgeIndex(Tensor):
|
|
922
917
|
return self.to_sparse_coo(value)
|
923
918
|
if layout == torch.sparse_csr:
|
924
919
|
return self.to_sparse_csr(value)
|
925
|
-
if
|
920
|
+
if layout == torch.sparse_csc:
|
926
921
|
return self.to_sparse_csc(value)
|
927
922
|
|
928
923
|
raise ValueError(f"Unexpected tensor layout (got '{layout}')")
|
@@ -1,4 +1,3 @@
|
|
1
|
-
import copy
|
2
1
|
import math
|
3
2
|
import sys
|
4
3
|
import time
|
@@ -114,25 +113,6 @@ class Linear(torch.nn.Module):
|
|
114
113
|
|
115
114
|
self.reset_parameters()
|
116
115
|
|
117
|
-
def __deepcopy__(self, memo):
|
118
|
-
# PyTorch<1.13 cannot handle deep copies of uninitialized parameters :(
|
119
|
-
# TODO Drop this code once PyTorch 1.12 is no longer supported.
|
120
|
-
out = Linear(
|
121
|
-
self.in_channels,
|
122
|
-
self.out_channels,
|
123
|
-
self.bias is not None,
|
124
|
-
self.weight_initializer,
|
125
|
-
self.bias_initializer,
|
126
|
-
).to(self.weight.device)
|
127
|
-
|
128
|
-
if self.in_channels > 0:
|
129
|
-
out.weight = copy.deepcopy(self.weight, memo)
|
130
|
-
|
131
|
-
if self.bias is not None:
|
132
|
-
out.bias = copy.deepcopy(self.bias, memo)
|
133
|
-
|
134
|
-
return out
|
135
|
-
|
136
116
|
def reset_parameters(self):
|
137
117
|
r"""Resets all learnable parameters of the module."""
|
138
118
|
reset_weight_(self.weight, self.in_channels, self.weight_initializer)
|
@@ -4,7 +4,6 @@ from typing import Optional
|
|
4
4
|
import torch
|
5
5
|
from torch import Tensor
|
6
6
|
|
7
|
-
import torch_geometric.typing
|
8
7
|
from torch_geometric.nn.pool.select import SelectOutput
|
9
8
|
|
10
9
|
|
@@ -49,8 +48,7 @@ class ConnectOutput:
|
|
49
48
|
self.batch = batch
|
50
49
|
|
51
50
|
|
52
|
-
|
53
|
-
ConnectOutput = torch.jit.script(ConnectOutput)
|
51
|
+
ConnectOutput = torch.jit.script(ConnectOutput)
|
54
52
|
|
55
53
|
|
56
54
|
class Connect(torch.nn.Module):
|
@@ -4,8 +4,6 @@ from typing import Optional
|
|
4
4
|
import torch
|
5
5
|
from torch import Tensor
|
6
6
|
|
7
|
-
import torch_geometric.typing
|
8
|
-
|
9
7
|
|
10
8
|
@dataclass(init=False)
|
11
9
|
class SelectOutput:
|
@@ -64,8 +62,7 @@ class SelectOutput:
|
|
64
62
|
self.weight = weight
|
65
63
|
|
66
64
|
|
67
|
-
|
68
|
-
SelectOutput = torch.jit.script(SelectOutput)
|
65
|
+
SelectOutput = torch.jit.script(SelectOutput)
|
69
66
|
|
70
67
|
|
71
68
|
class Select(torch.nn.Module):
|
torch_geometric/typing.py
CHANGED
@@ -21,8 +21,6 @@ WITH_PT23 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 3
|
|
21
21
|
WITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4
|
22
22
|
WITH_PT25 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 5
|
23
23
|
WITH_PT26 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 6
|
24
|
-
WITH_PT111 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 11
|
25
|
-
WITH_PT112 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 12
|
26
24
|
WITH_PT113 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 13
|
27
25
|
|
28
26
|
WITH_WINDOWS = os.name == 'nt'
|
@@ -1,11 +1,7 @@
|
|
1
1
|
from typing import List
|
2
2
|
|
3
|
-
import numpy as np
|
4
|
-
import torch
|
5
3
|
from torch import Tensor
|
6
4
|
|
7
|
-
import torch_geometric.typing
|
8
|
-
|
9
5
|
|
10
6
|
def lexsort(
|
11
7
|
keys: List[Tensor],
|
@@ -28,11 +24,6 @@ def lexsort(
|
|
28
24
|
"""
|
29
25
|
assert len(keys) >= 1
|
30
26
|
|
31
|
-
if not torch_geometric.typing.WITH_PT113:
|
32
|
-
keys = [k.neg() for k in keys] if descending else keys
|
33
|
-
out = np.lexsort([k.detach().cpu().numpy() for k in keys], axis=dim)
|
34
|
-
return torch.from_numpy(out).to(keys[0].device)
|
35
|
-
|
36
27
|
out = keys[0].argsort(dim=dim, descending=descending, stable=True)
|
37
28
|
for k in keys[1:]:
|
38
29
|
index = k.gather(dim, out)
|
@@ -8,185 +8,132 @@ from torch_geometric import is_compiling, is_in_onnx_export, warnings
|
|
8
8
|
from torch_geometric.typing import torch_scatter
|
9
9
|
from torch_geometric.utils.functions import cumsum
|
10
10
|
|
11
|
-
|
12
|
-
|
13
|
-
warnings.filterwarnings('ignore', '.*is in beta and the API may change.*')
|
14
|
-
|
15
|
-
def scatter(
|
16
|
-
src: Tensor,
|
17
|
-
index: Tensor,
|
18
|
-
dim: int = 0,
|
19
|
-
dim_size: Optional[int] = None,
|
20
|
-
reduce: str = 'sum',
|
21
|
-
) -> Tensor:
|
22
|
-
r"""Reduces all values from the :obj:`src` tensor at the indices
|
23
|
-
specified in the :obj:`index` tensor along a given dimension
|
24
|
-
:obj:`dim`. See the `documentation
|
25
|
-
<https://pytorch-scatter.readthedocs.io/en/latest/functions/
|
26
|
-
scatter.html>`__ of the :obj:`torch_scatter` package for more
|
27
|
-
information.
|
28
|
-
|
29
|
-
Args:
|
30
|
-
src (torch.Tensor): The source tensor.
|
31
|
-
index (torch.Tensor): The index tensor.
|
32
|
-
dim (int, optional): The dimension along which to index.
|
33
|
-
(default: :obj:`0`)
|
34
|
-
dim_size (int, optional): The size of the output tensor at
|
35
|
-
dimension :obj:`dim`. If set to :obj:`None`, will create a
|
36
|
-
minimal-sized output tensor according to
|
37
|
-
:obj:`index.max() + 1`. (default: :obj:`None`)
|
38
|
-
reduce (str, optional): The reduce operation (:obj:`"sum"`,
|
39
|
-
:obj:`"mean"`, :obj:`"mul"`, :obj:`"min"` or :obj:`"max"`,
|
40
|
-
:obj:`"any"`). (default: :obj:`"sum"`)
|
41
|
-
"""
|
42
|
-
if isinstance(index, Tensor) and index.dim() != 1:
|
43
|
-
raise ValueError(f"The `index` argument must be one-dimensional "
|
44
|
-
f"(got {index.dim()} dimensions)")
|
45
|
-
|
46
|
-
dim = src.dim() + dim if dim < 0 else dim
|
47
|
-
|
48
|
-
if isinstance(src, Tensor) and (dim < 0 or dim >= src.dim()):
|
49
|
-
raise ValueError(f"The `dim` argument must lay between 0 and "
|
50
|
-
f"{src.dim() - 1} (got {dim})")
|
51
|
-
|
52
|
-
if dim_size is None:
|
53
|
-
dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
|
54
|
-
|
55
|
-
# For now, we maintain various different code paths, based on whether
|
56
|
-
# the input requires gradients and whether it lays on the CPU/GPU.
|
57
|
-
# For example, `torch_scatter` is usually faster than
|
58
|
-
# `torch.scatter_reduce` on GPU, while `torch.scatter_reduce` is faster
|
59
|
-
# on CPU.
|
60
|
-
# `torch.scatter_reduce` has a faster forward implementation for
|
61
|
-
# "min"/"max" reductions since it does not compute additional arg
|
62
|
-
# indices, but is therefore way slower in its backward implementation.
|
63
|
-
# More insights can be found in `test/utils/test_scatter.py`.
|
64
|
-
|
65
|
-
size = src.size()[:dim] + (dim_size, ) + src.size()[dim + 1:]
|
66
|
-
|
67
|
-
# For "any" reduction, we use regular `scatter_`:
|
68
|
-
if reduce == 'any':
|
69
|
-
index = broadcast(index, src, dim)
|
70
|
-
return src.new_zeros(size).scatter_(dim, index, src)
|
11
|
+
warnings.filterwarnings('ignore', '.*is in beta and the API may change.*')
|
71
12
|
|
72
|
-
# For "sum" and "mean" reduction, we make use of `scatter_add_`:
|
73
|
-
if reduce == 'sum' or reduce == 'add':
|
74
|
-
index = broadcast(index, src, dim)
|
75
|
-
return src.new_zeros(size).scatter_add_(dim, index, src)
|
76
13
|
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
14
|
+
def scatter(
|
15
|
+
src: Tensor,
|
16
|
+
index: Tensor,
|
17
|
+
dim: int = 0,
|
18
|
+
dim_size: Optional[int] = None,
|
19
|
+
reduce: str = 'sum',
|
20
|
+
) -> Tensor:
|
21
|
+
r"""Reduces all values from the :obj:`src` tensor at the indices specified
|
22
|
+
in the :obj:`index` tensor along a given dimension ``dim``. See the
|
23
|
+
`documentation <https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html>`__ # noqa: E501
|
24
|
+
of the ``torch_scatter`` package for more information.
|
81
25
|
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
) -> Tensor:
|
151
|
-
r"""Reduces all values from the :obj:`src` tensor at the indices
|
152
|
-
specified in the :obj:`index` tensor along a given dimension
|
153
|
-
:obj:`dim`. See the `documentation
|
154
|
-
<https://pytorch-scatter.readthedocs.io/en/latest/functions/
|
155
|
-
scatter.html>`_ of the :obj:`torch_scatter` package for more
|
156
|
-
information.
|
157
|
-
|
158
|
-
Args:
|
159
|
-
src (torch.Tensor): The source tensor.
|
160
|
-
index (torch.Tensor): The index tensor.
|
161
|
-
dim (int, optional): The dimension along which to index.
|
162
|
-
(default: :obj:`0`)
|
163
|
-
dim_size (int, optional): The size of the output tensor at
|
164
|
-
dimension :obj:`dim`. If set to :obj:`None`, will create a
|
165
|
-
minimal-sized output tensor according to
|
166
|
-
:obj:`index.max() + 1`. (default: :obj:`None`)
|
167
|
-
reduce (str, optional): The reduce operation (:obj:`"sum"`,
|
168
|
-
:obj:`"mean"`, :obj:`"mul"`, :obj:`"min"` or :obj:`"max"`,
|
169
|
-
:obj:`"any"`). (default: :obj:`"sum"`)
|
170
|
-
"""
|
171
|
-
if reduce == 'any':
|
172
|
-
dim = src.dim() + dim if dim < 0 else dim
|
173
|
-
|
174
|
-
if dim_size is None:
|
175
|
-
dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
|
176
|
-
|
177
|
-
size = src.size()[:dim] + (dim_size, ) + src.size()[dim + 1:]
|
26
|
+
Args:
|
27
|
+
src (torch.Tensor): The source tensor.
|
28
|
+
index (torch.Tensor): The index tensor.
|
29
|
+
dim (int, optional): The dimension along which to index.
|
30
|
+
(default: ``0``)
|
31
|
+
dim_size (int, optional): The size of the output tensor at dimension
|
32
|
+
``dim``. If set to :obj:`None`, will create a minimal-sized output
|
33
|
+
tensor according to ``index.max() + 1``. (default: :obj:`None`)
|
34
|
+
reduce (str, optional): The reduce operation (``"sum"``, ``"mean"``,
|
35
|
+
``"mul"``, ``"min"``, ``"max"`` or ``"any"``). (default: ``"sum"``)
|
36
|
+
"""
|
37
|
+
if isinstance(index, Tensor) and index.dim() != 1:
|
38
|
+
raise ValueError(f"The `index` argument must be one-dimensional "
|
39
|
+
f"(got {index.dim()} dimensions)")
|
40
|
+
|
41
|
+
dim = src.dim() + dim if dim < 0 else dim
|
42
|
+
|
43
|
+
if isinstance(src, Tensor) and (dim < 0 or dim >= src.dim()):
|
44
|
+
raise ValueError(f"The `dim` argument must lay between 0 and "
|
45
|
+
f"{src.dim() - 1} (got {dim})")
|
46
|
+
|
47
|
+
if dim_size is None:
|
48
|
+
dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
|
49
|
+
|
50
|
+
# For now, we maintain various different code paths, based on whether
|
51
|
+
# the input requires gradients and whether it lays on the CPU/GPU.
|
52
|
+
# For example, `torch_scatter` is usually faster than
|
53
|
+
# `torch.scatter_reduce` on GPU, while `torch.scatter_reduce` is faster
|
54
|
+
# on CPU.
|
55
|
+
# `torch.scatter_reduce` has a faster forward implementation for
|
56
|
+
# "min"/"max" reductions since it does not compute additional arg
|
57
|
+
# indices, but is therefore way slower in its backward implementation.
|
58
|
+
# More insights can be found in `test/utils/test_scatter.py`.
|
59
|
+
|
60
|
+
size = src.size()[:dim] + (dim_size, ) + src.size()[dim + 1:]
|
61
|
+
|
62
|
+
# For "any" reduction, we use regular `scatter_`:
|
63
|
+
if reduce == 'any':
|
64
|
+
index = broadcast(index, src, dim)
|
65
|
+
return src.new_zeros(size).scatter_(dim, index, src)
|
66
|
+
|
67
|
+
# For "sum" and "mean" reduction, we make use of `scatter_add_`:
|
68
|
+
if reduce == 'sum' or reduce == 'add':
|
69
|
+
index = broadcast(index, src, dim)
|
70
|
+
return src.new_zeros(size).scatter_add_(dim, index, src)
|
71
|
+
|
72
|
+
if reduce == 'mean':
|
73
|
+
count = src.new_zeros(dim_size)
|
74
|
+
count.scatter_add_(0, index, src.new_ones(src.size(dim)))
|
75
|
+
count = count.clamp(min=1)
|
76
|
+
|
77
|
+
index = broadcast(index, src, dim)
|
78
|
+
out = src.new_zeros(size).scatter_add_(dim, index, src)
|
79
|
+
|
80
|
+
return out / broadcast(count, out, dim)
|
81
|
+
|
82
|
+
# For "min" and "max" reduction, we prefer `scatter_reduce_` on CPU or
|
83
|
+
# in case the input does not require gradients:
|
84
|
+
if reduce in ['min', 'max', 'amin', 'amax']:
|
85
|
+
if (not torch_geometric.typing.WITH_TORCH_SCATTER or is_compiling()
|
86
|
+
or is_in_onnx_export() or not src.is_cuda
|
87
|
+
or not src.requires_grad):
|
88
|
+
|
89
|
+
if (src.is_cuda and src.requires_grad and not is_compiling()
|
90
|
+
and not is_in_onnx_export()):
|
91
|
+
warnings.warn(f"The usage of `scatter(reduce='{reduce}')` "
|
92
|
+
f"can be accelerated via the 'torch-scatter'"
|
93
|
+
f" package, but it was not found")
|
178
94
|
|
179
95
|
index = broadcast(index, src, dim)
|
180
|
-
|
96
|
+
if not is_in_onnx_export():
|
97
|
+
return src.new_zeros(size).scatter_reduce_(
|
98
|
+
dim, index, src, reduce=f'a{reduce[-3:]}',
|
99
|
+
include_self=False)
|
100
|
+
|
101
|
+
fill = torch.full( # type: ignore
|
102
|
+
size=(1, ),
|
103
|
+
fill_value=src.min() if 'max' in reduce else src.max(),
|
104
|
+
dtype=src.dtype,
|
105
|
+
device=src.device,
|
106
|
+
).expand_as(src)
|
107
|
+
out = src.new_zeros(size).scatter_reduce_(dim, index, fill,
|
108
|
+
reduce=f'a{reduce[-3:]}',
|
109
|
+
include_self=True)
|
110
|
+
return out.scatter_reduce_(dim, index, src,
|
111
|
+
reduce=f'a{reduce[-3:]}',
|
112
|
+
include_self=True)
|
181
113
|
|
182
|
-
|
183
|
-
|
114
|
+
return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
|
115
|
+
reduce=reduce[-3:])
|
116
|
+
|
117
|
+
# For "mul" reduction, we prefer `scatter_reduce_` on CPU:
|
118
|
+
if reduce == 'mul':
|
119
|
+
if (not torch_geometric.typing.WITH_TORCH_SCATTER or is_compiling()
|
120
|
+
or not src.is_cuda):
|
121
|
+
|
122
|
+
if src.is_cuda and not is_compiling():
|
123
|
+
warnings.warn(f"The usage of `scatter(reduce='{reduce}')` "
|
124
|
+
f"can be accelerated via the 'torch-scatter'"
|
125
|
+
f" package, but it was not found")
|
184
126
|
|
185
|
-
|
186
|
-
|
127
|
+
index = broadcast(index, src, dim)
|
128
|
+
# We initialize with `one` here to match `scatter_mul` output:
|
129
|
+
return src.new_ones(size).scatter_reduce_(dim, index, src,
|
130
|
+
reduce='prod',
|
131
|
+
include_self=True)
|
187
132
|
|
188
133
|
return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
|
189
|
-
reduce=
|
134
|
+
reduce='mul')
|
135
|
+
|
136
|
+
raise ValueError(f"Encountered invalid `reduce` argument '{reduce}'")
|
190
137
|
|
191
138
|
|
192
139
|
def broadcast(src: Tensor, ref: Tensor, dim: int) -> Tensor:
|
@@ -215,24 +162,18 @@ def scatter_argmax(
|
|
215
162
|
if dim_size is None:
|
216
163
|
dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
|
217
164
|
|
218
|
-
if
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
include_self=False)
|
223
|
-
else:
|
224
|
-
# `include_self=False` is currently not supported by ONNX:
|
225
|
-
res = src.new_full(
|
226
|
-
size=(dim_size, ),
|
227
|
-
fill_value=src.min(), # type: ignore
|
228
|
-
)
|
229
|
-
res.scatter_reduce_(0, index, src.detach(), reduce="amax",
|
230
|
-
include_self=True)
|
231
|
-
elif torch_geometric.typing.WITH_PT111:
|
232
|
-
res = torch.scatter_reduce(src.detach(), 0, index, reduce='amax',
|
233
|
-
output_size=dim_size) # type: ignore
|
165
|
+
if not is_in_onnx_export():
|
166
|
+
res = src.new_empty(dim_size)
|
167
|
+
res.scatter_reduce_(0, index, src.detach(), reduce='amax',
|
168
|
+
include_self=False)
|
234
169
|
else:
|
235
|
-
|
170
|
+
# `include_self=False` is currently not supported by ONNX:
|
171
|
+
res = src.new_full(
|
172
|
+
size=(dim_size, ),
|
173
|
+
fill_value=src.min(), # type: ignore
|
174
|
+
)
|
175
|
+
res.scatter_reduce_(0, index, src.detach(), reduce="amax",
|
176
|
+
include_self=True)
|
236
177
|
|
237
178
|
out = index.new_full((dim_size, ), fill_value=dim_size - 1)
|
238
179
|
nonzero = (src == res[index]).nonzero().view(-1)
|
@@ -290,13 +231,7 @@ def group_argsort(
|
|
290
231
|
|
291
232
|
# Compute `grouped_argsort`:
|
292
233
|
src = src - 2 * index if descending else src + 2 * index
|
293
|
-
|
294
|
-
perm = src.argsort(descending=descending, stable=stable)
|
295
|
-
else:
|
296
|
-
perm = src.argsort(descending=descending)
|
297
|
-
if stable:
|
298
|
-
warnings.warn("Ignoring option `stable=True` in 'group_argsort' "
|
299
|
-
"since it requires PyTorch >= 1.13.0")
|
234
|
+
perm = src.argsort(descending=descending, stable=stable)
|
300
235
|
out = torch.empty_like(index)
|
301
236
|
out[perm] = torch.arange(index.numel(), device=index.device)
|
302
237
|
|
@@ -107,8 +107,6 @@ def sort_edge_index( # noqa: F811
|
|
107
107
|
num_nodes = maybe_num_nodes(edge_index, num_nodes)
|
108
108
|
|
109
109
|
if num_nodes * num_nodes > torch_geometric.typing.MAX_INT64:
|
110
|
-
if not torch_geometric.typing.WITH_PT113:
|
111
|
-
raise ValueError("'sort_edge_index' will result in an overflow")
|
112
110
|
perm = lexsort(keys=[
|
113
111
|
edge_index[int(sort_by_row)],
|
114
112
|
edge_index[1 - int(sort_by_row)],
|
torch_geometric/utils/_spmm.py
CHANGED
@@ -115,8 +115,7 @@ def spmm(
|
|
115
115
|
if src.layout == torch.sparse_csr:
|
116
116
|
ptr = src.crow_indices()
|
117
117
|
deg = ptr[1:] - ptr[:-1]
|
118
|
-
elif
|
119
|
-
and src.layout == torch.sparse_csc):
|
118
|
+
elif src.layout == torch.sparse_csc:
|
120
119
|
assert src.layout == torch.sparse_csc
|
121
120
|
ones = torch.ones_like(src.values())
|
122
121
|
index = src.row_indices()
|
torch_geometric/utils/sparse.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1
|
-
import typing
|
2
1
|
import warnings
|
3
2
|
from typing import Any, List, Optional, Tuple, Union
|
4
3
|
|
@@ -124,8 +123,7 @@ def is_torch_sparse_tensor(src: Any) -> bool:
|
|
124
123
|
return True
|
125
124
|
if src.layout == torch.sparse_csr:
|
126
125
|
return True
|
127
|
-
if
|
128
|
-
and src.layout == torch.sparse_csc):
|
126
|
+
if src.layout == torch.sparse_csc:
|
129
127
|
return True
|
130
128
|
return False
|
131
129
|
|
@@ -320,12 +318,6 @@ def to_torch_csc_tensor(
|
|
320
318
|
size=(4, 4), nnz=6, layout=torch.sparse_csc)
|
321
319
|
|
322
320
|
"""
|
323
|
-
if not torch_geometric.typing.WITH_PT112:
|
324
|
-
if typing.TYPE_CHECKING:
|
325
|
-
raise NotImplementedError
|
326
|
-
return torch_geometric.typing.MockTorchCSCTensor(
|
327
|
-
edge_index, edge_attr, size)
|
328
|
-
|
329
321
|
if size is None:
|
330
322
|
size = int(edge_index.max()) + 1
|
331
323
|
|
@@ -392,7 +384,7 @@ def to_torch_sparse_tensor(
|
|
392
384
|
return to_torch_coo_tensor(edge_index, edge_attr, size, is_coalesced)
|
393
385
|
if layout == torch.sparse_csr:
|
394
386
|
return to_torch_csr_tensor(edge_index, edge_attr, size, is_coalesced)
|
395
|
-
if
|
387
|
+
if layout == torch.sparse_csc:
|
396
388
|
return to_torch_csc_tensor(edge_index, edge_attr, size, is_coalesced)
|
397
389
|
|
398
390
|
raise ValueError(f"Unexpected sparse tensor layout (got '{layout}')")
|
@@ -431,7 +423,7 @@ def to_edge_index(adj: Union[Tensor, SparseTensor]) -> Tuple[Tensor, Tensor]:
|
|
431
423
|
col = adj.col_indices().detach()
|
432
424
|
return torch.stack([row, col], dim=0).long(), adj.values()
|
433
425
|
|
434
|
-
if
|
426
|
+
if adj.layout == torch.sparse_csc:
|
435
427
|
col = ptr2index(adj.ccol_indices().detach())
|
436
428
|
row = adj.row_indices().detach()
|
437
429
|
return torch.stack([row, col], dim=0).long(), adj.values()
|
@@ -480,7 +472,7 @@ def set_sparse_value(adj: Tensor, value: Tensor) -> Tensor:
|
|
480
472
|
device=value.device,
|
481
473
|
)
|
482
474
|
|
483
|
-
if
|
475
|
+
if adj.layout == torch.sparse_csc:
|
484
476
|
return torch.sparse_csc_tensor(
|
485
477
|
ccol_indices=adj.ccol_indices(),
|
486
478
|
row_indices=adj.row_indices(),
|
File without changes
|
{pyg_nightly-2.7.0.dev20250503.dist-info → pyg_nightly-2.7.0.dev20250505.dist-info}/licenses/LICENSE
RENAMED
File without changes
|