pyg-nightly 2.7.0.dev20241111__py3-none-any.whl → 2.7.0.dev20241114__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20241111.dist-info → pyg_nightly-2.7.0.dev20241114.dist-info}/METADATA +14 -14
- {pyg_nightly-2.7.0.dev20241111.dist-info → pyg_nightly-2.7.0.dev20241114.dist-info}/RECORD +13 -13
- torch_geometric/__init__.py +1 -1
- torch_geometric/_compile.py +1 -1
- torch_geometric/data/data.py +7 -1
- torch_geometric/data/feature_store.py +13 -18
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/typing.py +1 -0
- torch_geometric/utils/_scatter.py +1 -1
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/visualization/graph.py +1 -1
- {pyg_nightly-2.7.0.dev20241111.dist-info → pyg_nightly-2.7.0.dev20241114.dist-info}/WHEEL +0 -0
{pyg_nightly-2.7.0.dev20241111.dist-info → pyg_nightly-2.7.0.dev20241114.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: pyg-nightly
|
3
|
-
Version: 2.7.0.
|
3
|
+
Version: 2.7.0.dev20241114
|
4
4
|
Summary: Graph Neural Network Library for PyTorch
|
5
5
|
Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
|
6
6
|
Author-email: Matthias Fey <matthias@pyg.org>
|
@@ -446,12 +446,12 @@ We recommend to start with a minimal installation, and install additional depend
|
|
446
446
|
|
447
447
|
For ease of installation of these extensions, we provide `pip` wheels for all major OS/PyTorch/CUDA combinations, see [here](https://data.pyg.org/whl).
|
448
448
|
|
449
|
-
#### PyTorch 2.
|
449
|
+
#### PyTorch 2.5
|
450
450
|
|
451
|
-
To install the binaries for PyTorch 2.
|
451
|
+
To install the binaries for PyTorch 2.5.0, simply run
|
452
452
|
|
453
453
|
```
|
454
|
-
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.
|
454
|
+
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.5.0+${CUDA}.html
|
455
455
|
```
|
456
456
|
|
457
457
|
where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124` depending on your PyTorch installation.
|
@@ -462,23 +462,23 @@ where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124`
|
|
462
462
|
| **Windows** | ✅ | ✅ | ✅ | ✅ |
|
463
463
|
| **macOS** | ✅ | | | |
|
464
464
|
|
465
|
-
#### PyTorch 2.
|
465
|
+
#### PyTorch 2.4
|
466
466
|
|
467
|
-
To install the binaries for PyTorch 2.
|
467
|
+
To install the binaries for PyTorch 2.4.0, simply run
|
468
468
|
|
469
469
|
```
|
470
|
-
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.
|
470
|
+
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.4.0+${CUDA}.html
|
471
471
|
```
|
472
472
|
|
473
|
-
where `${CUDA}` should be replaced by either `cpu`, `cu118`, or `
|
473
|
+
where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124` depending on your PyTorch installation.
|
474
474
|
|
475
|
-
| | `cpu` | `cu118` | `cu121` |
|
476
|
-
| ----------- | ----- | ------- | ------- |
|
477
|
-
| **Linux** | ✅ | ✅ | ✅ |
|
478
|
-
| **Windows** | ✅ | ✅ | ✅ |
|
479
|
-
| **macOS** | ✅ | | |
|
475
|
+
| | `cpu` | `cu118` | `cu121` | `cu124` |
|
476
|
+
| ----------- | ----- | ------- | ------- | ------- |
|
477
|
+
| **Linux** | ✅ | ✅ | ✅ | ✅ |
|
478
|
+
| **Windows** | ✅ | ✅ | ✅ | ✅ |
|
479
|
+
| **macOS** | ✅ | | | |
|
480
480
|
|
481
|
-
**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2,
|
481
|
+
**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, PyTorch 2.2.0/2.2.1/2.2.2, and PyTorch 2.3.0/2.3.1 (following the same procedure).
|
482
482
|
**For older versions, you might need to explicitly specify the latest supported version number** or install via `pip install --no-index` in order to prevent a manual installation from source.
|
483
483
|
You can look up the latest supported version number [here](https://data.pyg.org/whl).
|
484
484
|
|
@@ -1,5 +1,5 @@
|
|
1
|
-
torch_geometric/__init__.py,sha256=
|
2
|
-
torch_geometric/_compile.py,sha256=
|
1
|
+
torch_geometric/__init__.py,sha256=48k-an8_K-Va5jV0wsOlouiB4PjW_d1wuzw3wABJB3M,1904
|
2
|
+
torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
|
3
3
|
torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
|
4
4
|
torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
|
5
5
|
torch_geometric/config_mixin.py,sha256=DGQXStKNiPp4iBvtx7aVoofWcUSuVKEHdG5WUL2oNJs,4230
|
@@ -18,7 +18,7 @@ torch_geometric/logging.py,sha256=HmHHLiCcM64k-6UYNOSfXPIeSGNAyiGGcn8cD8tlyuQ,85
|
|
18
18
|
torch_geometric/resolver.py,sha256=fn-_6mCpI2xv7eDZnIFcYrHOn0IrwbkWFLDb9laQrWI,1270
|
19
19
|
torch_geometric/seed.py,sha256=MJLbVwpb9i8mK3oi32sS__Cq-dRq_afTeoOL_HoA9ko,372
|
20
20
|
torch_geometric/template.py,sha256=rqjDWgcSAgTCiV4bkOjWRPaO4PpUdC_RXigzxxBqAu8,1060
|
21
|
-
torch_geometric/typing.py,sha256=
|
21
|
+
torch_geometric/typing.py,sha256=0pxCLue4iqqFC-k5ByqAeyg2mogtWXqgtod3ZOEMq1A,13933
|
22
22
|
torch_geometric/warnings.py,sha256=t114CbkrmiqkXaavx5g7OO52dLdktf-U__B5QqYIQvI,413
|
23
23
|
torch_geometric/contrib/__init__.py,sha256=0pWkmXfZtbdr-AKwlii5LTFggTEH-MCrSKpZxrtPlVs,352
|
24
24
|
torch_geometric/contrib/datasets/__init__.py,sha256=lrGnWsEiJf5zsBRmshGZZFN_uYR2ezDjbj9n9nCpvtk,23
|
@@ -32,13 +32,13 @@ torch_geometric/contrib/transforms/__init__.py,sha256=lrGnWsEiJf5zsBRmshGZZFN_uY
|
|
32
32
|
torch_geometric/data/__init__.py,sha256=OLkV82AGm6xMSynT_DHfRE6_INfPxLx4BQnY0-WVn54,4323
|
33
33
|
torch_geometric/data/batch.py,sha256=C9cT7-rcWPgnG68Eb_uAcn90HS3OvOG6n4fY3ihpFhI,8764
|
34
34
|
torch_geometric/data/collate.py,sha256=RRiUMBLxDAitaHx7zF0qiMR2nW1NY_0uaNdxlUo5-bo,12756
|
35
|
-
torch_geometric/data/data.py,sha256=
|
35
|
+
torch_geometric/data/data.py,sha256=l_gHy18g9WtiSCm1mDinR4vGrZOLetogrw5wJEcn23E,43807
|
36
36
|
torch_geometric/data/database.py,sha256=VTct1xyzXsK0GZahBV9-noviCzjRteAsKMG7VgJ52n0,22998
|
37
37
|
torch_geometric/data/datapipes.py,sha256=9_Cq3j_7LIF4plQFzbLaqyy0LcpKdAic6yiKgMqSX9A,3083
|
38
38
|
torch_geometric/data/dataset.py,sha256=TX2AM3OQkMLOx5Ie8IFtFFYuoA3AGeYwoT3ZqW56N7c,16768
|
39
39
|
torch_geometric/data/download.py,sha256=kcesTu6jlgmCeePpOxDQOnVhxB_GuZ9iu9ds72KEORc,1889
|
40
40
|
torch_geometric/data/extract.py,sha256=X_f0JEo67DF9hOpIlq3QPWXA9RF8uoVFi195UjstzDc,2324
|
41
|
-
torch_geometric/data/feature_store.py,sha256=
|
41
|
+
torch_geometric/data/feature_store.py,sha256=ma65GAHHEoYiZqHs_CkMGAYxeepGc1Bp0TMXmioIfCs,20044
|
42
42
|
torch_geometric/data/graph_store.py,sha256=oFrLDNP5hKf3HWWsFsjcamx5vLIEk8JnLjuGpjrFLdc,13867
|
43
43
|
torch_geometric/data/hetero_data.py,sha256=q0L3bENyEvo_BGLPwZPVzh730Aak6sQ7yXoawPgM72E,47982
|
44
44
|
torch_geometric/data/hypergraph_data.py,sha256=33hsXW25Yz4Ju8mKajYinZOrkqrUi1SqThG7MlOOYNM,8294
|
@@ -515,9 +515,9 @@ torch_geometric/transforms/cartesian.py,sha256=_gdFrPP5q3aPmQW6QvYeI8-nvKNVyQF-W
|
|
515
515
|
torch_geometric/transforms/center.py,sha256=4avx4_wm7Q11epOaMQ2YaVcdcTFjozPBEWG2h6GyKc4,645
|
516
516
|
torch_geometric/transforms/compose.py,sha256=P5AFGd6s9L-lpb8io1jKIm2LjAccp_6Q2XocR5T1z5c,1658
|
517
517
|
torch_geometric/transforms/constant.py,sha256=zDJbO1sEds1vjbRmgzSd-8D8gM4PtvWESuC-gX2qB9E,2005
|
518
|
-
torch_geometric/transforms/delaunay.py,sha256
|
518
|
+
torch_geometric/transforms/delaunay.py,sha256=2mbTqs7oeDZQdifrqldLaLeCZz7pGd_5VMSLOzKMIGE,2668
|
519
519
|
torch_geometric/transforms/distance.py,sha256=DvvI2vAYpxklnKCz3-4w2EXz7AYra0xBZ5m7MAL1tok,2360
|
520
|
-
torch_geometric/transforms/face_to_edge.py,sha256=
|
520
|
+
torch_geometric/transforms/face_to_edge.py,sha256=ohAWtpiCs_qHwrrYmz2eLvAcd_bhfPsxkpHdFnAZkhk,2144
|
521
521
|
torch_geometric/transforms/feature_propagation.py,sha256=GPiKiGU7OuOpBBJeATlCtAtBUy_DSHUoBnJdDK8T81E,3056
|
522
522
|
torch_geometric/transforms/fixed_points.py,sha256=sfcqHZSw542LIYmq1DrTJdyncDRa2Uxf5N50G5lYSfQ,2426
|
523
523
|
torch_geometric/transforms/gcn_norm.py,sha256=INi8f8J3i2OXWgX5U4GNKROpcvJNW42qO39EdLPRPS8,1397
|
@@ -578,13 +578,13 @@ torch_geometric/utils/_negative_sampling.py,sha256=u-7oDg8luSFto-iUqsq7eC9uek6yq
|
|
578
578
|
torch_geometric/utils/_normalize_edge_index.py,sha256=H6DY-Dzi1Psr3igG_nb0U3ZPNZz-BBDntO2iuA8FtzA,1682
|
579
579
|
torch_geometric/utils/_normalized_cut.py,sha256=uwVJkl-Q0tpY-w0nvcHajcQYcqFh1oDOf55XELdjJBU,1167
|
580
580
|
torch_geometric/utils/_one_hot.py,sha256=vXC7l7zudYRZIwWv6mT-Biuk2zKELyqteJXLynPocPM,1404
|
581
|
-
torch_geometric/utils/_scatter.py,sha256=
|
581
|
+
torch_geometric/utils/_scatter.py,sha256=f8nSA_zZXO2YwKMnaGcx_Cz-11UdMxfck-hl0B6Mcng,14614
|
582
582
|
torch_geometric/utils/_segment.py,sha256=CqS7_NMQihX89gEwFVHbyMEZgaEnSlJGpyuWqy3i8HI,1976
|
583
583
|
torch_geometric/utils/_select.py,sha256=BZ5P6-1riR4xhCIJZnsNg5HmeAGelRzH42TpADj9xpQ,2439
|
584
584
|
torch_geometric/utils/_softmax.py,sha256=6dTVbWX04laemRP-ZFPMS6ymRZtRa8zYF22QCXl_m4w,3242
|
585
585
|
torch_geometric/utils/_sort_edge_index.py,sha256=oPS1AkAmPm5Aq8sSGI-e-OkeXOnU5X58Q86eamHi4gA,4500
|
586
586
|
torch_geometric/utils/_spmm.py,sha256=uRq21nCgostC2jE6SKfp2xIY4_BtrXOG5YHbWUEPv10,5794
|
587
|
-
torch_geometric/utils/_subgraph.py,sha256=
|
587
|
+
torch_geometric/utils/_subgraph.py,sha256=GcOGNUcVe97tifuQyi5qBZ88A_Wo3-o17l9xCSIsau4,18456
|
588
588
|
torch_geometric/utils/_to_dense_adj.py,sha256=hl1sboUBvED5Er66bqLms4VdmxKA-7Y3ozJIR-YIAUc,3606
|
589
589
|
torch_geometric/utils/_to_dense_batch.py,sha256=-K5NjjfvjKYKJQ3kXgNIDR7lwMJ_GGISI45b50IGMvY,4582
|
590
590
|
torch_geometric/utils/_train_test_split_edges.py,sha256=KnBDgnaKuJYTHUOIlvFtzvkHUe-93DG3ckST4-wOERM,3569
|
@@ -616,8 +616,8 @@ torch_geometric/utils/smiles.py,sha256=4xTW56OWqvQcM5i2LEvsESAIvd2n0I17n9tvarHok
|
|
616
616
|
torch_geometric/utils/sparse.py,sha256=uYd0oPrp5XN0c2Zc15f-00rhhVMfLnRMqNcqcmILNKQ,25519
|
617
617
|
torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5nUAUjw,6222
|
618
618
|
torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
|
619
|
-
torch_geometric/visualization/graph.py,sha256=
|
619
|
+
torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
|
620
620
|
torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
|
621
|
-
pyg_nightly-2.7.0.
|
622
|
-
pyg_nightly-2.7.0.
|
623
|
-
pyg_nightly-2.7.0.
|
621
|
+
pyg_nightly-2.7.0.dev20241114.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
|
622
|
+
pyg_nightly-2.7.0.dev20241114.dist-info/METADATA,sha256=urxpy2ZlbcTbEvAnn12xbIIGtjyQbXfja_co7d7DC9c,62979
|
623
|
+
pyg_nightly-2.7.0.dev20241114.dist-info/RECORD,,
|
torch_geometric/__init__.py
CHANGED
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
|
|
30
30
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
31
31
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
32
32
|
|
33
|
-
__version__ = '2.7.0.
|
33
|
+
__version__ = '2.7.0.dev20241114'
|
34
34
|
|
35
35
|
__all__ = [
|
36
36
|
'Index',
|
torch_geometric/_compile.py
CHANGED
torch_geometric/data/data.py
CHANGED
@@ -659,7 +659,13 @@ class Data(BaseData, FeatureStore, GraphStore):
|
|
659
659
|
return value.get_dim_size()
|
660
660
|
return int(value.max()) + 1
|
661
661
|
elif 'index' in key or key == 'face':
|
662
|
-
|
662
|
+
num_nodes = self.num_nodes
|
663
|
+
if num_nodes is None:
|
664
|
+
raise RuntimeError(f"Unable to infer 'num_nodes' from the "
|
665
|
+
f"attribute '{key}'. Please explicitly set "
|
666
|
+
f"'num_nodes' as an attribute of 'data' to "
|
667
|
+
f"prevent this error")
|
668
|
+
return num_nodes
|
663
669
|
else:
|
664
670
|
return 0
|
665
671
|
|
@@ -74,13 +74,6 @@ class TensorAttr(CastMixin):
|
|
74
74
|
r"""Whether the :obj:`TensorAttr` has no unset fields."""
|
75
75
|
return all([self.is_set(key) for key in self.__dataclass_fields__])
|
76
76
|
|
77
|
-
def fully_specify(self) -> 'TensorAttr':
|
78
|
-
r"""Sets all :obj:`UNSET` fields to :obj:`None`."""
|
79
|
-
for key in self.__dataclass_fields__:
|
80
|
-
if not self.is_set(key):
|
81
|
-
setattr(self, key, None)
|
82
|
-
return self
|
83
|
-
|
84
77
|
def update(self, attr: 'TensorAttr') -> 'TensorAttr':
|
85
78
|
r"""Updates an :class:`TensorAttr` with set attributes from another
|
86
79
|
:class:`TensorAttr`.
|
@@ -230,10 +223,11 @@ class AttrView(CastMixin):
|
|
230
223
|
|
231
224
|
store[group_name, attr_name]()
|
232
225
|
"""
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
226
|
+
attr = copy.copy(self._attr)
|
227
|
+
for key in attr.__dataclass_fields__: # Set all UNSET values to None.
|
228
|
+
if not attr.is_set(key):
|
229
|
+
setattr(attr, key, None)
|
230
|
+
return self._store.get_tensor(attr)
|
237
231
|
|
238
232
|
def __copy__(self) -> 'AttrView':
|
239
233
|
out = self.__class__.__new__(self.__class__)
|
@@ -479,9 +473,7 @@ class FeatureStore(ABC):
|
|
479
473
|
# CastMixin will handle the case of key being a tuple or TensorAttr
|
480
474
|
# object:
|
481
475
|
key = self._tensor_attr_cls.cast(key)
|
482
|
-
|
483
|
-
# sense to work with a view here:
|
484
|
-
key.fully_specify()
|
476
|
+
assert key.is_fully_specified()
|
485
477
|
self.put_tensor(value, key)
|
486
478
|
|
487
479
|
def __getitem__(self, key: TensorAttr) -> Any:
|
@@ -503,13 +495,16 @@ class FeatureStore(ABC):
|
|
503
495
|
# If the view is not fully-specified, return a :class:`AttrView`:
|
504
496
|
return self.view(attr)
|
505
497
|
|
506
|
-
def __delitem__(self,
|
498
|
+
def __delitem__(self, attr: TensorAttr):
|
507
499
|
r"""Supports :obj:`del store[tensor_attr]`."""
|
508
500
|
# CastMixin will handle the case of key being a tuple or TensorAttr
|
509
501
|
# object:
|
510
|
-
|
511
|
-
|
512
|
-
|
502
|
+
attr = self._tensor_attr_cls.cast(attr)
|
503
|
+
attr = copy.copy(attr)
|
504
|
+
for key in attr.__dataclass_fields__: # Set all UNSET values to None.
|
505
|
+
if not attr.is_set(key):
|
506
|
+
setattr(attr, key, None)
|
507
|
+
self.remove_tensor(attr)
|
513
508
|
|
514
509
|
def __iter__(self):
|
515
510
|
raise NotImplementedError
|
@@ -1,3 +1,5 @@
|
|
1
|
+
from typing import List
|
2
|
+
|
1
3
|
import torch
|
2
4
|
|
3
5
|
from torch_geometric.data import Data
|
@@ -5,30 +7,78 @@ from torch_geometric.data.datapipes import functional_transform
|
|
5
7
|
from torch_geometric.transforms import BaseTransform
|
6
8
|
|
7
9
|
|
10
|
+
class _QhullTransform(BaseTransform):
|
11
|
+
r"""Q-hull implementation of delaunay triangulation."""
|
12
|
+
def forward(self, data: Data) -> Data:
|
13
|
+
assert data.pos is not None
|
14
|
+
import scipy.spatial
|
15
|
+
|
16
|
+
pos = data.pos.cpu().numpy()
|
17
|
+
tri = scipy.spatial.Delaunay(pos, qhull_options='QJ')
|
18
|
+
face = torch.from_numpy(tri.simplices)
|
19
|
+
|
20
|
+
data.face = face.t().contiguous().to(data.pos.device, torch.long)
|
21
|
+
return data
|
22
|
+
|
23
|
+
|
24
|
+
class _ShullTransform(BaseTransform):
|
25
|
+
r"""Sweep-hull implementation of delaunay triangulation."""
|
26
|
+
def forward(self, data: Data) -> Data:
|
27
|
+
assert data.pos is not None
|
28
|
+
from torch_delaunay.functional import shull2d
|
29
|
+
|
30
|
+
face = shull2d(data.pos.cpu())
|
31
|
+
data.face = face.t().contiguous().to(data.pos.device)
|
32
|
+
return data
|
33
|
+
|
34
|
+
|
35
|
+
class _SequentialTransform(BaseTransform):
|
36
|
+
r"""Runs the first successful transformation.
|
37
|
+
|
38
|
+
All intermediate exceptions are suppressed except the last.
|
39
|
+
"""
|
40
|
+
def __init__(self, transforms: List[BaseTransform]) -> None:
|
41
|
+
assert len(transforms) > 0
|
42
|
+
self.transforms = transforms
|
43
|
+
|
44
|
+
def forward(self, data: Data) -> Data:
|
45
|
+
for i, transform in enumerate(self.transforms):
|
46
|
+
try:
|
47
|
+
return transform.forward(data)
|
48
|
+
except ImportError as e:
|
49
|
+
if i == len(self.transforms) - 1:
|
50
|
+
raise e
|
51
|
+
return data
|
52
|
+
|
53
|
+
|
8
54
|
@functional_transform('delaunay')
|
9
55
|
class Delaunay(BaseTransform):
|
10
56
|
r"""Computes the delaunay triangulation of a set of points
|
11
57
|
(functional name: :obj:`delaunay`).
|
58
|
+
|
59
|
+
.. hint::
|
60
|
+
Consider installing the
|
61
|
+
`torch_delaunay <https://github.com/ybubnov/torch_delaunay>`_ package
|
62
|
+
to speed up computation.
|
12
63
|
"""
|
13
|
-
def
|
14
|
-
|
64
|
+
def __init__(self) -> None:
|
65
|
+
self._transform = _SequentialTransform([
|
66
|
+
_ShullTransform(),
|
67
|
+
_QhullTransform(),
|
68
|
+
])
|
15
69
|
|
70
|
+
def forward(self, data: Data) -> Data:
|
16
71
|
assert data.pos is not None
|
72
|
+
device = data.pos.device
|
17
73
|
|
18
74
|
if data.pos.size(0) < 2:
|
19
|
-
data.edge_index = torch.
|
20
|
-
|
21
|
-
|
22
|
-
data.edge_index = torch.tensor([[0, 1], [1, 0]],
|
23
|
-
device=data.pos.device)
|
75
|
+
data.edge_index = torch.empty(2, 0, dtype=torch.long,
|
76
|
+
device=device)
|
77
|
+
elif data.pos.size(0) == 2:
|
78
|
+
data.edge_index = torch.tensor([[0, 1], [1, 0]], device=device)
|
24
79
|
elif data.pos.size(0) == 3:
|
25
|
-
data.face = torch.tensor([[0], [1], [2]],
|
26
|
-
|
27
|
-
|
28
|
-
pos = data.pos.cpu().numpy()
|
29
|
-
tri = scipy.spatial.Delaunay(pos, qhull_options='QJ')
|
30
|
-
face = torch.from_numpy(tri.simplices)
|
31
|
-
|
32
|
-
data.face = face.t().contiguous().to(data.pos.device, torch.long)
|
80
|
+
data.face = torch.tensor([[0], [1], [2]], device=device)
|
81
|
+
else:
|
82
|
+
data = self._transform.forward(data)
|
33
83
|
|
34
84
|
return data
|
@@ -8,8 +8,15 @@ from torch_geometric.utils import to_undirected
|
|
8
8
|
|
9
9
|
@functional_transform('face_to_edge')
|
10
10
|
class FaceToEdge(BaseTransform):
|
11
|
-
r"""Converts mesh faces :obj:`[3, num_faces]`
|
12
|
-
:obj:`[
|
11
|
+
r"""Converts mesh faces of shape :obj:`[3, num_faces]` or
|
12
|
+
:obj:`[4, num_faces]` to edge indices of shape :obj:`[2, num_edges]`
|
13
|
+
(functional name: :obj:`face_to_edge`).
|
14
|
+
|
15
|
+
This transform supports both 2D triangular faces, represented by a
|
16
|
+
tensor of shape :obj:`[3, num_faces]`, and 3D tetrahedral mesh faces,
|
17
|
+
represented by a tensor of shape :obj:`[4, num_faces]`. It will convert
|
18
|
+
these faces into edge indices, where each edge is defined by the indices
|
19
|
+
of its two endpoints.
|
13
20
|
|
14
21
|
Args:
|
15
22
|
remove_faces (bool, optional): If set to :obj:`False`, the face tensor
|
@@ -22,7 +29,29 @@ class FaceToEdge(BaseTransform):
|
|
22
29
|
if hasattr(data, 'face'):
|
23
30
|
assert data.face is not None
|
24
31
|
face = data.face
|
25
|
-
|
32
|
+
|
33
|
+
if face.size(0) not in [3, 4]:
|
34
|
+
raise RuntimeError(f"Expected 'face' tensor with shape "
|
35
|
+
f"[3, num_faces] or [4, num_faces] "
|
36
|
+
f"(got {list(face.size())})")
|
37
|
+
|
38
|
+
if face.size()[0] == 3:
|
39
|
+
edge_index = torch.cat([
|
40
|
+
face[:2],
|
41
|
+
face[1:],
|
42
|
+
face[::2],
|
43
|
+
], dim=1)
|
44
|
+
else:
|
45
|
+
assert face.size()[0] == 4
|
46
|
+
edge_index = torch.cat([
|
47
|
+
face[:2],
|
48
|
+
face[1:3],
|
49
|
+
face[2:4],
|
50
|
+
face[::2],
|
51
|
+
face[1::2],
|
52
|
+
face[::3],
|
53
|
+
], dim=1)
|
54
|
+
|
26
55
|
edge_index = to_undirected(edge_index, num_nodes=data.num_nodes)
|
27
56
|
|
28
57
|
data.edge_index = edge_index
|
torch_geometric/typing.py
CHANGED
@@ -15,6 +15,7 @@ WITH_PT22 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 2
|
|
15
15
|
WITH_PT23 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 3
|
16
16
|
WITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4
|
17
17
|
WITH_PT25 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 5
|
18
|
+
WITH_PT26 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 6
|
18
19
|
WITH_PT111 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 11
|
19
20
|
WITH_PT112 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 12
|
20
21
|
WITH_PT113 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 13
|
@@ -351,5 +351,5 @@ def group_cat(
|
|
351
351
|
"""
|
352
352
|
assert len(tensors) == len(indices)
|
353
353
|
index, perm = torch.cat(indices).sort(stable=True)
|
354
|
-
out = torch.cat(tensors, dim=
|
354
|
+
out = torch.cat(tensors, dim=dim).index_select(dim, perm)
|
355
355
|
return (out, index) if return_index else out
|
@@ -346,10 +346,12 @@ def k_hop_subgraph(
|
|
346
346
|
|
347
347
|
subsets = [node_idx]
|
348
348
|
|
349
|
+
preserved_edge_mask = torch.zeros_like(edge_mask)
|
349
350
|
for _ in range(num_hops):
|
350
351
|
node_mask.fill_(False)
|
351
352
|
node_mask[subsets[-1]] = True
|
352
353
|
torch.index_select(node_mask, 0, row, out=edge_mask)
|
354
|
+
preserved_edge_mask |= edge_mask
|
353
355
|
subsets.append(col[edge_mask])
|
354
356
|
|
355
357
|
subset, inv = torch.cat(subsets).unique(return_inverse=True)
|
@@ -360,6 +362,8 @@ def k_hop_subgraph(
|
|
360
362
|
|
361
363
|
if not directed:
|
362
364
|
edge_mask = node_mask[row] & node_mask[col]
|
365
|
+
else:
|
366
|
+
edge_mask = preserved_edge_mask
|
363
367
|
|
364
368
|
edge_index = edge_index[:, edge_mask]
|
365
369
|
|
File without changes
|