pyg-nightly 2.7.0.dev20241113__py3-none-any.whl → 2.7.0.dev20241118__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20241113.dist-info → pyg_nightly-2.7.0.dev20241118.dist-info}/METADATA +14 -14
- {pyg_nightly-2.7.0.dev20241113.dist-info → pyg_nightly-2.7.0.dev20241118.dist-info}/RECORD +9 -9
- torch_geometric/__init__.py +1 -1
- torch_geometric/_compile.py +1 -1
- torch_geometric/data/feature_store.py +13 -18
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/typing.py +1 -0
- torch_geometric/utils/_subgraph.py +4 -0
- {pyg_nightly-2.7.0.dev20241113.dist-info → pyg_nightly-2.7.0.dev20241118.dist-info}/WHEEL +0 -0
{pyg_nightly-2.7.0.dev20241113.dist-info → pyg_nightly-2.7.0.dev20241118.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: pyg-nightly
|
3
|
-
Version: 2.7.0.
|
3
|
+
Version: 2.7.0.dev20241118
|
4
4
|
Summary: Graph Neural Network Library for PyTorch
|
5
5
|
Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
|
6
6
|
Author-email: Matthias Fey <matthias@pyg.org>
|
@@ -446,12 +446,12 @@ We recommend to start with a minimal installation, and install additional depend
|
|
446
446
|
|
447
447
|
For ease of installation of these extensions, we provide `pip` wheels for all major OS/PyTorch/CUDA combinations, see [here](https://data.pyg.org/whl).
|
448
448
|
|
449
|
-
#### PyTorch 2.
|
449
|
+
#### PyTorch 2.5
|
450
450
|
|
451
|
-
To install the binaries for PyTorch 2.
|
451
|
+
To install the binaries for PyTorch 2.5.0, simply run
|
452
452
|
|
453
453
|
```
|
454
|
-
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.
|
454
|
+
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.5.0+${CUDA}.html
|
455
455
|
```
|
456
456
|
|
457
457
|
where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124` depending on your PyTorch installation.
|
@@ -462,23 +462,23 @@ where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124`
|
|
462
462
|
| **Windows** | ✅ | ✅ | ✅ | ✅ |
|
463
463
|
| **macOS** | ✅ | | | |
|
464
464
|
|
465
|
-
#### PyTorch 2.
|
465
|
+
#### PyTorch 2.4
|
466
466
|
|
467
|
-
To install the binaries for PyTorch 2.
|
467
|
+
To install the binaries for PyTorch 2.4.0, simply run
|
468
468
|
|
469
469
|
```
|
470
|
-
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.
|
470
|
+
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.4.0+${CUDA}.html
|
471
471
|
```
|
472
472
|
|
473
|
-
where `${CUDA}` should be replaced by either `cpu`, `cu118`, or `
|
473
|
+
where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124` depending on your PyTorch installation.
|
474
474
|
|
475
|
-
| | `cpu` | `cu118` | `cu121` |
|
476
|
-
| ----------- | ----- | ------- | ------- |
|
477
|
-
| **Linux** | ✅ | ✅ | ✅ |
|
478
|
-
| **Windows** | ✅ | ✅ | ✅ |
|
479
|
-
| **macOS** | ✅ | | |
|
475
|
+
| | `cpu` | `cu118` | `cu121` | `cu124` |
|
476
|
+
| ----------- | ----- | ------- | ------- | ------- |
|
477
|
+
| **Linux** | ✅ | ✅ | ✅ | ✅ |
|
478
|
+
| **Windows** | ✅ | ✅ | ✅ | ✅ |
|
479
|
+
| **macOS** | ✅ | | | |
|
480
480
|
|
481
|
-
**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2,
|
481
|
+
**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, PyTorch 2.2.0/2.2.1/2.2.2, and PyTorch 2.3.0/2.3.1 (following the same procedure).
|
482
482
|
**For older versions, you might need to explicitly specify the latest supported version number** or install via `pip install --no-index` in order to prevent a manual installation from source.
|
483
483
|
You can look up the latest supported version number [here](https://data.pyg.org/whl).
|
484
484
|
|
@@ -1,5 +1,5 @@
|
|
1
|
-
torch_geometric/__init__.py,sha256=
|
2
|
-
torch_geometric/_compile.py,sha256=
|
1
|
+
torch_geometric/__init__.py,sha256=vTJWoDf2s2ZvILNC7ZptH1Bn78dW0KnWBGKrR_Vgolo,1904
|
2
|
+
torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
|
3
3
|
torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
|
4
4
|
torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
|
5
5
|
torch_geometric/config_mixin.py,sha256=DGQXStKNiPp4iBvtx7aVoofWcUSuVKEHdG5WUL2oNJs,4230
|
@@ -18,7 +18,7 @@ torch_geometric/logging.py,sha256=HmHHLiCcM64k-6UYNOSfXPIeSGNAyiGGcn8cD8tlyuQ,85
|
|
18
18
|
torch_geometric/resolver.py,sha256=fn-_6mCpI2xv7eDZnIFcYrHOn0IrwbkWFLDb9laQrWI,1270
|
19
19
|
torch_geometric/seed.py,sha256=MJLbVwpb9i8mK3oi32sS__Cq-dRq_afTeoOL_HoA9ko,372
|
20
20
|
torch_geometric/template.py,sha256=rqjDWgcSAgTCiV4bkOjWRPaO4PpUdC_RXigzxxBqAu8,1060
|
21
|
-
torch_geometric/typing.py,sha256=
|
21
|
+
torch_geometric/typing.py,sha256=0pxCLue4iqqFC-k5ByqAeyg2mogtWXqgtod3ZOEMq1A,13933
|
22
22
|
torch_geometric/warnings.py,sha256=t114CbkrmiqkXaavx5g7OO52dLdktf-U__B5QqYIQvI,413
|
23
23
|
torch_geometric/contrib/__init__.py,sha256=0pWkmXfZtbdr-AKwlii5LTFggTEH-MCrSKpZxrtPlVs,352
|
24
24
|
torch_geometric/contrib/datasets/__init__.py,sha256=lrGnWsEiJf5zsBRmshGZZFN_uYR2ezDjbj9n9nCpvtk,23
|
@@ -38,7 +38,7 @@ torch_geometric/data/datapipes.py,sha256=9_Cq3j_7LIF4plQFzbLaqyy0LcpKdAic6yiKgMq
|
|
38
38
|
torch_geometric/data/dataset.py,sha256=TX2AM3OQkMLOx5Ie8IFtFFYuoA3AGeYwoT3ZqW56N7c,16768
|
39
39
|
torch_geometric/data/download.py,sha256=kcesTu6jlgmCeePpOxDQOnVhxB_GuZ9iu9ds72KEORc,1889
|
40
40
|
torch_geometric/data/extract.py,sha256=X_f0JEo67DF9hOpIlq3QPWXA9RF8uoVFi195UjstzDc,2324
|
41
|
-
torch_geometric/data/feature_store.py,sha256=
|
41
|
+
torch_geometric/data/feature_store.py,sha256=ma65GAHHEoYiZqHs_CkMGAYxeepGc1Bp0TMXmioIfCs,20044
|
42
42
|
torch_geometric/data/graph_store.py,sha256=oFrLDNP5hKf3HWWsFsjcamx5vLIEk8JnLjuGpjrFLdc,13867
|
43
43
|
torch_geometric/data/hetero_data.py,sha256=q0L3bENyEvo_BGLPwZPVzh730Aak6sQ7yXoawPgM72E,47982
|
44
44
|
torch_geometric/data/hypergraph_data.py,sha256=33hsXW25Yz4Ju8mKajYinZOrkqrUi1SqThG7MlOOYNM,8294
|
@@ -515,7 +515,7 @@ torch_geometric/transforms/cartesian.py,sha256=_gdFrPP5q3aPmQW6QvYeI8-nvKNVyQF-W
|
|
515
515
|
torch_geometric/transforms/center.py,sha256=4avx4_wm7Q11epOaMQ2YaVcdcTFjozPBEWG2h6GyKc4,645
|
516
516
|
torch_geometric/transforms/compose.py,sha256=P5AFGd6s9L-lpb8io1jKIm2LjAccp_6Q2XocR5T1z5c,1658
|
517
517
|
torch_geometric/transforms/constant.py,sha256=zDJbO1sEds1vjbRmgzSd-8D8gM4PtvWESuC-gX2qB9E,2005
|
518
|
-
torch_geometric/transforms/delaunay.py,sha256
|
518
|
+
torch_geometric/transforms/delaunay.py,sha256=2mbTqs7oeDZQdifrqldLaLeCZz7pGd_5VMSLOzKMIGE,2668
|
519
519
|
torch_geometric/transforms/distance.py,sha256=DvvI2vAYpxklnKCz3-4w2EXz7AYra0xBZ5m7MAL1tok,2360
|
520
520
|
torch_geometric/transforms/face_to_edge.py,sha256=ohAWtpiCs_qHwrrYmz2eLvAcd_bhfPsxkpHdFnAZkhk,2144
|
521
521
|
torch_geometric/transforms/feature_propagation.py,sha256=GPiKiGU7OuOpBBJeATlCtAtBUy_DSHUoBnJdDK8T81E,3056
|
@@ -584,7 +584,7 @@ torch_geometric/utils/_select.py,sha256=BZ5P6-1riR4xhCIJZnsNg5HmeAGelRzH42TpADj9
|
|
584
584
|
torch_geometric/utils/_softmax.py,sha256=6dTVbWX04laemRP-ZFPMS6ymRZtRa8zYF22QCXl_m4w,3242
|
585
585
|
torch_geometric/utils/_sort_edge_index.py,sha256=oPS1AkAmPm5Aq8sSGI-e-OkeXOnU5X58Q86eamHi4gA,4500
|
586
586
|
torch_geometric/utils/_spmm.py,sha256=uRq21nCgostC2jE6SKfp2xIY4_BtrXOG5YHbWUEPv10,5794
|
587
|
-
torch_geometric/utils/_subgraph.py,sha256=
|
587
|
+
torch_geometric/utils/_subgraph.py,sha256=GcOGNUcVe97tifuQyi5qBZ88A_Wo3-o17l9xCSIsau4,18456
|
588
588
|
torch_geometric/utils/_to_dense_adj.py,sha256=hl1sboUBvED5Er66bqLms4VdmxKA-7Y3ozJIR-YIAUc,3606
|
589
589
|
torch_geometric/utils/_to_dense_batch.py,sha256=-K5NjjfvjKYKJQ3kXgNIDR7lwMJ_GGISI45b50IGMvY,4582
|
590
590
|
torch_geometric/utils/_train_test_split_edges.py,sha256=KnBDgnaKuJYTHUOIlvFtzvkHUe-93DG3ckST4-wOERM,3569
|
@@ -618,6 +618,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
|
|
618
618
|
torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
|
619
619
|
torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
|
620
620
|
torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
|
621
|
-
pyg_nightly-2.7.0.
|
622
|
-
pyg_nightly-2.7.0.
|
623
|
-
pyg_nightly-2.7.0.
|
621
|
+
pyg_nightly-2.7.0.dev20241118.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
|
622
|
+
pyg_nightly-2.7.0.dev20241118.dist-info/METADATA,sha256=1dIlfK5AzrXv52TTBmLL5OpSOhWL_ncnbsVQpwyY33Q,62979
|
623
|
+
pyg_nightly-2.7.0.dev20241118.dist-info/RECORD,,
|
torch_geometric/__init__.py
CHANGED
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
|
|
30
30
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
31
31
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
32
32
|
|
33
|
-
__version__ = '2.7.0.
|
33
|
+
__version__ = '2.7.0.dev20241118'
|
34
34
|
|
35
35
|
__all__ = [
|
36
36
|
'Index',
|
torch_geometric/_compile.py
CHANGED
@@ -74,13 +74,6 @@ class TensorAttr(CastMixin):
|
|
74
74
|
r"""Whether the :obj:`TensorAttr` has no unset fields."""
|
75
75
|
return all([self.is_set(key) for key in self.__dataclass_fields__])
|
76
76
|
|
77
|
-
def fully_specify(self) -> 'TensorAttr':
|
78
|
-
r"""Sets all :obj:`UNSET` fields to :obj:`None`."""
|
79
|
-
for key in self.__dataclass_fields__:
|
80
|
-
if not self.is_set(key):
|
81
|
-
setattr(self, key, None)
|
82
|
-
return self
|
83
|
-
|
84
77
|
def update(self, attr: 'TensorAttr') -> 'TensorAttr':
|
85
78
|
r"""Updates an :class:`TensorAttr` with set attributes from another
|
86
79
|
:class:`TensorAttr`.
|
@@ -230,10 +223,11 @@ class AttrView(CastMixin):
|
|
230
223
|
|
231
224
|
store[group_name, attr_name]()
|
232
225
|
"""
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
226
|
+
attr = copy.copy(self._attr)
|
227
|
+
for key in attr.__dataclass_fields__: # Set all UNSET values to None.
|
228
|
+
if not attr.is_set(key):
|
229
|
+
setattr(attr, key, None)
|
230
|
+
return self._store.get_tensor(attr)
|
237
231
|
|
238
232
|
def __copy__(self) -> 'AttrView':
|
239
233
|
out = self.__class__.__new__(self.__class__)
|
@@ -479,9 +473,7 @@ class FeatureStore(ABC):
|
|
479
473
|
# CastMixin will handle the case of key being a tuple or TensorAttr
|
480
474
|
# object:
|
481
475
|
key = self._tensor_attr_cls.cast(key)
|
482
|
-
|
483
|
-
# sense to work with a view here:
|
484
|
-
key.fully_specify()
|
476
|
+
assert key.is_fully_specified()
|
485
477
|
self.put_tensor(value, key)
|
486
478
|
|
487
479
|
def __getitem__(self, key: TensorAttr) -> Any:
|
@@ -503,13 +495,16 @@ class FeatureStore(ABC):
|
|
503
495
|
# If the view is not fully-specified, return a :class:`AttrView`:
|
504
496
|
return self.view(attr)
|
505
497
|
|
506
|
-
def __delitem__(self,
|
498
|
+
def __delitem__(self, attr: TensorAttr):
|
507
499
|
r"""Supports :obj:`del store[tensor_attr]`."""
|
508
500
|
# CastMixin will handle the case of key being a tuple or TensorAttr
|
509
501
|
# object:
|
510
|
-
|
511
|
-
|
512
|
-
|
502
|
+
attr = self._tensor_attr_cls.cast(attr)
|
503
|
+
attr = copy.copy(attr)
|
504
|
+
for key in attr.__dataclass_fields__: # Set all UNSET values to None.
|
505
|
+
if not attr.is_set(key):
|
506
|
+
setattr(attr, key, None)
|
507
|
+
self.remove_tensor(attr)
|
513
508
|
|
514
509
|
def __iter__(self):
|
515
510
|
raise NotImplementedError
|
@@ -1,3 +1,5 @@
|
|
1
|
+
from typing import List
|
2
|
+
|
1
3
|
import torch
|
2
4
|
|
3
5
|
from torch_geometric.data import Data
|
@@ -5,30 +7,78 @@ from torch_geometric.data.datapipes import functional_transform
|
|
5
7
|
from torch_geometric.transforms import BaseTransform
|
6
8
|
|
7
9
|
|
10
|
+
class _QhullTransform(BaseTransform):
|
11
|
+
r"""Q-hull implementation of delaunay triangulation."""
|
12
|
+
def forward(self, data: Data) -> Data:
|
13
|
+
assert data.pos is not None
|
14
|
+
import scipy.spatial
|
15
|
+
|
16
|
+
pos = data.pos.cpu().numpy()
|
17
|
+
tri = scipy.spatial.Delaunay(pos, qhull_options='QJ')
|
18
|
+
face = torch.from_numpy(tri.simplices)
|
19
|
+
|
20
|
+
data.face = face.t().contiguous().to(data.pos.device, torch.long)
|
21
|
+
return data
|
22
|
+
|
23
|
+
|
24
|
+
class _ShullTransform(BaseTransform):
|
25
|
+
r"""Sweep-hull implementation of delaunay triangulation."""
|
26
|
+
def forward(self, data: Data) -> Data:
|
27
|
+
assert data.pos is not None
|
28
|
+
from torch_delaunay.functional import shull2d
|
29
|
+
|
30
|
+
face = shull2d(data.pos.cpu())
|
31
|
+
data.face = face.t().contiguous().to(data.pos.device)
|
32
|
+
return data
|
33
|
+
|
34
|
+
|
35
|
+
class _SequentialTransform(BaseTransform):
|
36
|
+
r"""Runs the first successful transformation.
|
37
|
+
|
38
|
+
All intermediate exceptions are suppressed except the last.
|
39
|
+
"""
|
40
|
+
def __init__(self, transforms: List[BaseTransform]) -> None:
|
41
|
+
assert len(transforms) > 0
|
42
|
+
self.transforms = transforms
|
43
|
+
|
44
|
+
def forward(self, data: Data) -> Data:
|
45
|
+
for i, transform in enumerate(self.transforms):
|
46
|
+
try:
|
47
|
+
return transform.forward(data)
|
48
|
+
except ImportError as e:
|
49
|
+
if i == len(self.transforms) - 1:
|
50
|
+
raise e
|
51
|
+
return data
|
52
|
+
|
53
|
+
|
8
54
|
@functional_transform('delaunay')
|
9
55
|
class Delaunay(BaseTransform):
|
10
56
|
r"""Computes the delaunay triangulation of a set of points
|
11
57
|
(functional name: :obj:`delaunay`).
|
58
|
+
|
59
|
+
.. hint::
|
60
|
+
Consider installing the
|
61
|
+
`torch_delaunay <https://github.com/ybubnov/torch_delaunay>`_ package
|
62
|
+
to speed up computation.
|
12
63
|
"""
|
13
|
-
def
|
14
|
-
|
64
|
+
def __init__(self) -> None:
|
65
|
+
self._transform = _SequentialTransform([
|
66
|
+
_ShullTransform(),
|
67
|
+
_QhullTransform(),
|
68
|
+
])
|
15
69
|
|
70
|
+
def forward(self, data: Data) -> Data:
|
16
71
|
assert data.pos is not None
|
72
|
+
device = data.pos.device
|
17
73
|
|
18
74
|
if data.pos.size(0) < 2:
|
19
|
-
data.edge_index = torch.
|
20
|
-
|
21
|
-
|
22
|
-
data.edge_index = torch.tensor([[0, 1], [1, 0]],
|
23
|
-
device=data.pos.device)
|
75
|
+
data.edge_index = torch.empty(2, 0, dtype=torch.long,
|
76
|
+
device=device)
|
77
|
+
elif data.pos.size(0) == 2:
|
78
|
+
data.edge_index = torch.tensor([[0, 1], [1, 0]], device=device)
|
24
79
|
elif data.pos.size(0) == 3:
|
25
|
-
data.face = torch.tensor([[0], [1], [2]],
|
26
|
-
|
27
|
-
|
28
|
-
pos = data.pos.cpu().numpy()
|
29
|
-
tri = scipy.spatial.Delaunay(pos, qhull_options='QJ')
|
30
|
-
face = torch.from_numpy(tri.simplices)
|
31
|
-
|
32
|
-
data.face = face.t().contiguous().to(data.pos.device, torch.long)
|
80
|
+
data.face = torch.tensor([[0], [1], [2]], device=device)
|
81
|
+
else:
|
82
|
+
data = self._transform.forward(data)
|
33
83
|
|
34
84
|
return data
|
torch_geometric/typing.py
CHANGED
@@ -15,6 +15,7 @@ WITH_PT22 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 2
|
|
15
15
|
WITH_PT23 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 3
|
16
16
|
WITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4
|
17
17
|
WITH_PT25 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 5
|
18
|
+
WITH_PT26 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 6
|
18
19
|
WITH_PT111 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 11
|
19
20
|
WITH_PT112 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 12
|
20
21
|
WITH_PT113 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 13
|
@@ -346,10 +346,12 @@ def k_hop_subgraph(
|
|
346
346
|
|
347
347
|
subsets = [node_idx]
|
348
348
|
|
349
|
+
preserved_edge_mask = torch.zeros_like(edge_mask)
|
349
350
|
for _ in range(num_hops):
|
350
351
|
node_mask.fill_(False)
|
351
352
|
node_mask[subsets[-1]] = True
|
352
353
|
torch.index_select(node_mask, 0, row, out=edge_mask)
|
354
|
+
preserved_edge_mask |= edge_mask
|
353
355
|
subsets.append(col[edge_mask])
|
354
356
|
|
355
357
|
subset, inv = torch.cat(subsets).unique(return_inverse=True)
|
@@ -360,6 +362,8 @@ def k_hop_subgraph(
|
|
360
362
|
|
361
363
|
if not directed:
|
362
364
|
edge_mask = node_mask[row] & node_mask[col]
|
365
|
+
else:
|
366
|
+
edge_mask = preserved_edge_mask
|
363
367
|
|
364
368
|
edge_index = edge_index[:, edge_mask]
|
365
369
|
|
File without changes
|