pyg-nightly 2.7.0.dev20241111__py3-none-any.whl → 2.7.0.dev20241114__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: pyg-nightly
3
- Version: 2.7.0.dev20241111
3
+ Version: 2.7.0.dev20241114
4
4
  Summary: Graph Neural Network Library for PyTorch
5
5
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
6
6
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -446,12 +446,12 @@ We recommend to start with a minimal installation, and install additional depend
446
446
 
447
447
  For ease of installation of these extensions, we provide `pip` wheels for all major OS/PyTorch/CUDA combinations, see [here](https://data.pyg.org/whl).
448
448
 
449
- #### PyTorch 2.4
449
+ #### PyTorch 2.5
450
450
 
451
- To install the binaries for PyTorch 2.4.0, simply run
451
+ To install the binaries for PyTorch 2.5.0, simply run
452
452
 
453
453
  ```
454
- pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.4.0+${CUDA}.html
454
+ pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.5.0+${CUDA}.html
455
455
  ```
456
456
 
457
457
  where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124` depending on your PyTorch installation.
@@ -462,23 +462,23 @@ where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124`
462
462
  | **Windows** | ✅ | ✅ | ✅ | ✅ |
463
463
  | **macOS** | ✅ | | | |
464
464
 
465
- #### PyTorch 2.3
465
+ #### PyTorch 2.4
466
466
 
467
- To install the binaries for PyTorch 2.3.0, simply run
467
+ To install the binaries for PyTorch 2.4.0, simply run
468
468
 
469
469
  ```
470
- pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.3.0+${CUDA}.html
470
+ pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.4.0+${CUDA}.html
471
471
  ```
472
472
 
473
- where `${CUDA}` should be replaced by either `cpu`, `cu118`, or `cu121` depending on your PyTorch installation.
473
+ where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124` depending on your PyTorch installation.
474
474
 
475
- | | `cpu` | `cu118` | `cu121` |
476
- | ----------- | ----- | ------- | ------- |
477
- | **Linux** | ✅ | ✅ | ✅ |
478
- | **Windows** | ✅ | ✅ | ✅ |
479
- | **macOS** | ✅ | | |
475
+ | | `cpu` | `cu118` | `cu121` | `cu124` |
476
+ | ----------- | ----- | ------- | ------- | ------- |
477
+ | **Linux** | ✅ | ✅ | ✅ | ✅ |
478
+ | **Windows** | ✅ | ✅ | ✅ | ✅ |
479
+ | **macOS** | ✅ | | | |
480
480
 
481
- **Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, and PyTorch 2.2.0/2.2.1/2.2.2 (following the same procedure).
481
+ **Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, PyTorch 2.2.0/2.2.1/2.2.2, and PyTorch 2.3.0/2.3.1 (following the same procedure).
482
482
  **For older versions, you might need to explicitly specify the latest supported version number** or install via `pip install --no-index` in order to prevent a manual installation from source.
483
483
  You can look up the latest supported version number [here](https://data.pyg.org/whl).
484
484
 
@@ -1,5 +1,5 @@
1
- torch_geometric/__init__.py,sha256=zixV3WisaNJLj2ys2odFv7wocaCbwFr4A8uxGnDc63A,1904
2
- torch_geometric/_compile.py,sha256=REjj1_qX8YBrva6iqr3AsNiDueTAy2BhLZkdezKL2MY,1322
1
+ torch_geometric/__init__.py,sha256=48k-an8_K-Va5jV0wsOlouiB4PjW_d1wuzw3wABJB3M,1904
2
+ torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
3
3
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
4
4
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
5
5
  torch_geometric/config_mixin.py,sha256=DGQXStKNiPp4iBvtx7aVoofWcUSuVKEHdG5WUL2oNJs,4230
@@ -18,7 +18,7 @@ torch_geometric/logging.py,sha256=HmHHLiCcM64k-6UYNOSfXPIeSGNAyiGGcn8cD8tlyuQ,85
18
18
  torch_geometric/resolver.py,sha256=fn-_6mCpI2xv7eDZnIFcYrHOn0IrwbkWFLDb9laQrWI,1270
19
19
  torch_geometric/seed.py,sha256=MJLbVwpb9i8mK3oi32sS__Cq-dRq_afTeoOL_HoA9ko,372
20
20
  torch_geometric/template.py,sha256=rqjDWgcSAgTCiV4bkOjWRPaO4PpUdC_RXigzxxBqAu8,1060
21
- torch_geometric/typing.py,sha256=OU7zhpnwarQ5cCzl8Sfvh5aKr3RDG3tZot7-WO4a_Yo,13865
21
+ torch_geometric/typing.py,sha256=0pxCLue4iqqFC-k5ByqAeyg2mogtWXqgtod3ZOEMq1A,13933
22
22
  torch_geometric/warnings.py,sha256=t114CbkrmiqkXaavx5g7OO52dLdktf-U__B5QqYIQvI,413
23
23
  torch_geometric/contrib/__init__.py,sha256=0pWkmXfZtbdr-AKwlii5LTFggTEH-MCrSKpZxrtPlVs,352
24
24
  torch_geometric/contrib/datasets/__init__.py,sha256=lrGnWsEiJf5zsBRmshGZZFN_uYR2ezDjbj9n9nCpvtk,23
@@ -32,13 +32,13 @@ torch_geometric/contrib/transforms/__init__.py,sha256=lrGnWsEiJf5zsBRmshGZZFN_uY
32
32
  torch_geometric/data/__init__.py,sha256=OLkV82AGm6xMSynT_DHfRE6_INfPxLx4BQnY0-WVn54,4323
33
33
  torch_geometric/data/batch.py,sha256=C9cT7-rcWPgnG68Eb_uAcn90HS3OvOG6n4fY3ihpFhI,8764
34
34
  torch_geometric/data/collate.py,sha256=RRiUMBLxDAitaHx7zF0qiMR2nW1NY_0uaNdxlUo5-bo,12756
35
- torch_geometric/data/data.py,sha256=6HeA8tSMAcjWDsJqtRDMC0mmgnvvHCUDWfIusA7ObBA,43445
35
+ torch_geometric/data/data.py,sha256=l_gHy18g9WtiSCm1mDinR4vGrZOLetogrw5wJEcn23E,43807
36
36
  torch_geometric/data/database.py,sha256=VTct1xyzXsK0GZahBV9-noviCzjRteAsKMG7VgJ52n0,22998
37
37
  torch_geometric/data/datapipes.py,sha256=9_Cq3j_7LIF4plQFzbLaqyy0LcpKdAic6yiKgMqSX9A,3083
38
38
  torch_geometric/data/dataset.py,sha256=TX2AM3OQkMLOx5Ie8IFtFFYuoA3AGeYwoT3ZqW56N7c,16768
39
39
  torch_geometric/data/download.py,sha256=kcesTu6jlgmCeePpOxDQOnVhxB_GuZ9iu9ds72KEORc,1889
40
40
  torch_geometric/data/extract.py,sha256=X_f0JEo67DF9hOpIlq3QPWXA9RF8uoVFi195UjstzDc,2324
41
- torch_geometric/data/feature_store.py,sha256=m_VzeKl_WLEPaT_OsPbVdBNUWH2vdKfxk28qCa1iVfA,20153
41
+ torch_geometric/data/feature_store.py,sha256=ma65GAHHEoYiZqHs_CkMGAYxeepGc1Bp0TMXmioIfCs,20044
42
42
  torch_geometric/data/graph_store.py,sha256=oFrLDNP5hKf3HWWsFsjcamx5vLIEk8JnLjuGpjrFLdc,13867
43
43
  torch_geometric/data/hetero_data.py,sha256=q0L3bENyEvo_BGLPwZPVzh730Aak6sQ7yXoawPgM72E,47982
44
44
  torch_geometric/data/hypergraph_data.py,sha256=33hsXW25Yz4Ju8mKajYinZOrkqrUi1SqThG7MlOOYNM,8294
@@ -515,9 +515,9 @@ torch_geometric/transforms/cartesian.py,sha256=_gdFrPP5q3aPmQW6QvYeI8-nvKNVyQF-W
515
515
  torch_geometric/transforms/center.py,sha256=4avx4_wm7Q11epOaMQ2YaVcdcTFjozPBEWG2h6GyKc4,645
516
516
  torch_geometric/transforms/compose.py,sha256=P5AFGd6s9L-lpb8io1jKIm2LjAccp_6Q2XocR5T1z5c,1658
517
517
  torch_geometric/transforms/constant.py,sha256=zDJbO1sEds1vjbRmgzSd-8D8gM4PtvWESuC-gX2qB9E,2005
518
- torch_geometric/transforms/delaunay.py,sha256=-7JIDKhjk8h1cFGVYiqRszwHFDf5dMzKOiB1ejUah_o,1273
518
+ torch_geometric/transforms/delaunay.py,sha256=2mbTqs7oeDZQdifrqldLaLeCZz7pGd_5VMSLOzKMIGE,2668
519
519
  torch_geometric/transforms/distance.py,sha256=DvvI2vAYpxklnKCz3-4w2EXz7AYra0xBZ5m7MAL1tok,2360
520
- torch_geometric/transforms/face_to_edge.py,sha256=Gr0NwRtVwp51hkDJ2bQVkGbSs2k6ePrRiua3ZNt6-K8,1083
520
+ torch_geometric/transforms/face_to_edge.py,sha256=ohAWtpiCs_qHwrrYmz2eLvAcd_bhfPsxkpHdFnAZkhk,2144
521
521
  torch_geometric/transforms/feature_propagation.py,sha256=GPiKiGU7OuOpBBJeATlCtAtBUy_DSHUoBnJdDK8T81E,3056
522
522
  torch_geometric/transforms/fixed_points.py,sha256=sfcqHZSw542LIYmq1DrTJdyncDRa2Uxf5N50G5lYSfQ,2426
523
523
  torch_geometric/transforms/gcn_norm.py,sha256=INi8f8J3i2OXWgX5U4GNKROpcvJNW42qO39EdLPRPS8,1397
@@ -578,13 +578,13 @@ torch_geometric/utils/_negative_sampling.py,sha256=u-7oDg8luSFto-iUqsq7eC9uek6yq
578
578
  torch_geometric/utils/_normalize_edge_index.py,sha256=H6DY-Dzi1Psr3igG_nb0U3ZPNZz-BBDntO2iuA8FtzA,1682
579
579
  torch_geometric/utils/_normalized_cut.py,sha256=uwVJkl-Q0tpY-w0nvcHajcQYcqFh1oDOf55XELdjJBU,1167
580
580
  torch_geometric/utils/_one_hot.py,sha256=vXC7l7zudYRZIwWv6mT-Biuk2zKELyqteJXLynPocPM,1404
581
- torch_geometric/utils/_scatter.py,sha256=yWayqkSPs5G5tsku2HPxJaHcqFRaSaam_L70Gb-7Uwg,14594
581
+ torch_geometric/utils/_scatter.py,sha256=f8nSA_zZXO2YwKMnaGcx_Cz-11UdMxfck-hl0B6Mcng,14614
582
582
  torch_geometric/utils/_segment.py,sha256=CqS7_NMQihX89gEwFVHbyMEZgaEnSlJGpyuWqy3i8HI,1976
583
583
  torch_geometric/utils/_select.py,sha256=BZ5P6-1riR4xhCIJZnsNg5HmeAGelRzH42TpADj9xpQ,2439
584
584
  torch_geometric/utils/_softmax.py,sha256=6dTVbWX04laemRP-ZFPMS6ymRZtRa8zYF22QCXl_m4w,3242
585
585
  torch_geometric/utils/_sort_edge_index.py,sha256=oPS1AkAmPm5Aq8sSGI-e-OkeXOnU5X58Q86eamHi4gA,4500
586
586
  torch_geometric/utils/_spmm.py,sha256=uRq21nCgostC2jE6SKfp2xIY4_BtrXOG5YHbWUEPv10,5794
587
- torch_geometric/utils/_subgraph.py,sha256=T7olS3B4QAy6BbC6FaHIgKT9usTdDBzcRnP-VCGmWL8,18311
587
+ torch_geometric/utils/_subgraph.py,sha256=GcOGNUcVe97tifuQyi5qBZ88A_Wo3-o17l9xCSIsau4,18456
588
588
  torch_geometric/utils/_to_dense_adj.py,sha256=hl1sboUBvED5Er66bqLms4VdmxKA-7Y3ozJIR-YIAUc,3606
589
589
  torch_geometric/utils/_to_dense_batch.py,sha256=-K5NjjfvjKYKJQ3kXgNIDR7lwMJ_GGISI45b50IGMvY,4582
590
590
  torch_geometric/utils/_train_test_split_edges.py,sha256=KnBDgnaKuJYTHUOIlvFtzvkHUe-93DG3ckST4-wOERM,3569
@@ -616,8 +616,8 @@ torch_geometric/utils/smiles.py,sha256=4xTW56OWqvQcM5i2LEvsESAIvd2n0I17n9tvarHok
616
616
  torch_geometric/utils/sparse.py,sha256=uYd0oPrp5XN0c2Zc15f-00rhhVMfLnRMqNcqcmILNKQ,25519
617
617
  torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5nUAUjw,6222
618
618
  torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
619
- torch_geometric/visualization/graph.py,sha256=AGKqbtTdL14w7xIhy6n3g4bpCOnujKt-pXHCNzovxB4,4784
619
+ torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
620
620
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
621
- pyg_nightly-2.7.0.dev20241111.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
622
- pyg_nightly-2.7.0.dev20241111.dist-info/METADATA,sha256=iI2qoQgS9-zDN86zQuJjwdCON3Bf8YD3B7HAySrpqTg,62897
623
- pyg_nightly-2.7.0.dev20241111.dist-info/RECORD,,
621
+ pyg_nightly-2.7.0.dev20241114.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
622
+ pyg_nightly-2.7.0.dev20241114.dist-info/METADATA,sha256=urxpy2ZlbcTbEvAnn12xbIIGtjyQbXfja_co7d7DC9c,62979
623
+ pyg_nightly-2.7.0.dev20241114.dist-info/RECORD,,
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
30
30
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
31
31
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
32
32
 
33
- __version__ = '2.7.0.dev20241111'
33
+ __version__ = '2.7.0.dev20241114'
34
34
 
35
35
  __all__ = [
36
36
  'Index',
@@ -38,4 +38,4 @@ def compile(
38
38
  """
39
39
  warnings.warn("'torch_geometric.compile' is deprecated in favor of "
40
40
  "'torch.compile'")
41
- return torch.compile(model, *args, **kwargs)
41
+ return torch.compile(model, *args, **kwargs) # type: ignore
@@ -659,7 +659,13 @@ class Data(BaseData, FeatureStore, GraphStore):
659
659
  return value.get_dim_size()
660
660
  return int(value.max()) + 1
661
661
  elif 'index' in key or key == 'face':
662
- return self.num_nodes
662
+ num_nodes = self.num_nodes
663
+ if num_nodes is None:
664
+ raise RuntimeError(f"Unable to infer 'num_nodes' from the "
665
+ f"attribute '{key}'. Please explicitly set "
666
+ f"'num_nodes' as an attribute of 'data' to "
667
+ f"prevent this error")
668
+ return num_nodes
663
669
  else:
664
670
  return 0
665
671
 
@@ -74,13 +74,6 @@ class TensorAttr(CastMixin):
74
74
  r"""Whether the :obj:`TensorAttr` has no unset fields."""
75
75
  return all([self.is_set(key) for key in self.__dataclass_fields__])
76
76
 
77
- def fully_specify(self) -> 'TensorAttr':
78
- r"""Sets all :obj:`UNSET` fields to :obj:`None`."""
79
- for key in self.__dataclass_fields__:
80
- if not self.is_set(key):
81
- setattr(self, key, None)
82
- return self
83
-
84
77
  def update(self, attr: 'TensorAttr') -> 'TensorAttr':
85
78
  r"""Updates an :class:`TensorAttr` with set attributes from another
86
79
  :class:`TensorAttr`.
@@ -230,10 +223,11 @@ class AttrView(CastMixin):
230
223
 
231
224
  store[group_name, attr_name]()
232
225
  """
233
- # Set all UNSET values to None:
234
- out = copy.copy(self)
235
- out._attr.fully_specify()
236
- return out._store.get_tensor(out._attr)
226
+ attr = copy.copy(self._attr)
227
+ for key in attr.__dataclass_fields__: # Set all UNSET values to None.
228
+ if not attr.is_set(key):
229
+ setattr(attr, key, None)
230
+ return self._store.get_tensor(attr)
237
231
 
238
232
  def __copy__(self) -> 'AttrView':
239
233
  out = self.__class__.__new__(self.__class__)
@@ -479,9 +473,7 @@ class FeatureStore(ABC):
479
473
  # CastMixin will handle the case of key being a tuple or TensorAttr
480
474
  # object:
481
475
  key = self._tensor_attr_cls.cast(key)
482
- # We need to fully-specify the key for __setitem__ as it does not make
483
- # sense to work with a view here:
484
- key.fully_specify()
476
+ assert key.is_fully_specified()
485
477
  self.put_tensor(value, key)
486
478
 
487
479
  def __getitem__(self, key: TensorAttr) -> Any:
@@ -503,13 +495,16 @@ class FeatureStore(ABC):
503
495
  # If the view is not fully-specified, return a :class:`AttrView`:
504
496
  return self.view(attr)
505
497
 
506
- def __delitem__(self, key: TensorAttr):
498
+ def __delitem__(self, attr: TensorAttr):
507
499
  r"""Supports :obj:`del store[tensor_attr]`."""
508
500
  # CastMixin will handle the case of key being a tuple or TensorAttr
509
501
  # object:
510
- key = self._tensor_attr_cls.cast(key)
511
- key.fully_specify()
512
- self.remove_tensor(key)
502
+ attr = self._tensor_attr_cls.cast(attr)
503
+ attr = copy.copy(attr)
504
+ for key in attr.__dataclass_fields__: # Set all UNSET values to None.
505
+ if not attr.is_set(key):
506
+ setattr(attr, key, None)
507
+ self.remove_tensor(attr)
513
508
 
514
509
  def __iter__(self):
515
510
  raise NotImplementedError
@@ -1,3 +1,5 @@
1
+ from typing import List
2
+
1
3
  import torch
2
4
 
3
5
  from torch_geometric.data import Data
@@ -5,30 +7,78 @@ from torch_geometric.data.datapipes import functional_transform
5
7
  from torch_geometric.transforms import BaseTransform
6
8
 
7
9
 
10
+ class _QhullTransform(BaseTransform):
11
+ r"""Q-hull implementation of delaunay triangulation."""
12
+ def forward(self, data: Data) -> Data:
13
+ assert data.pos is not None
14
+ import scipy.spatial
15
+
16
+ pos = data.pos.cpu().numpy()
17
+ tri = scipy.spatial.Delaunay(pos, qhull_options='QJ')
18
+ face = torch.from_numpy(tri.simplices)
19
+
20
+ data.face = face.t().contiguous().to(data.pos.device, torch.long)
21
+ return data
22
+
23
+
24
+ class _ShullTransform(BaseTransform):
25
+ r"""Sweep-hull implementation of delaunay triangulation."""
26
+ def forward(self, data: Data) -> Data:
27
+ assert data.pos is not None
28
+ from torch_delaunay.functional import shull2d
29
+
30
+ face = shull2d(data.pos.cpu())
31
+ data.face = face.t().contiguous().to(data.pos.device)
32
+ return data
33
+
34
+
35
+ class _SequentialTransform(BaseTransform):
36
+ r"""Runs the first successful transformation.
37
+
38
+ All intermediate exceptions are suppressed except the last.
39
+ """
40
+ def __init__(self, transforms: List[BaseTransform]) -> None:
41
+ assert len(transforms) > 0
42
+ self.transforms = transforms
43
+
44
+ def forward(self, data: Data) -> Data:
45
+ for i, transform in enumerate(self.transforms):
46
+ try:
47
+ return transform.forward(data)
48
+ except ImportError as e:
49
+ if i == len(self.transforms) - 1:
50
+ raise e
51
+ return data
52
+
53
+
8
54
  @functional_transform('delaunay')
9
55
  class Delaunay(BaseTransform):
10
56
  r"""Computes the delaunay triangulation of a set of points
11
57
  (functional name: :obj:`delaunay`).
58
+
59
+ .. hint::
60
+ Consider installing the
61
+ `torch_delaunay <https://github.com/ybubnov/torch_delaunay>`_ package
62
+ to speed up computation.
12
63
  """
13
- def forward(self, data: Data) -> Data:
14
- import scipy.spatial
64
+ def __init__(self) -> None:
65
+ self._transform = _SequentialTransform([
66
+ _ShullTransform(),
67
+ _QhullTransform(),
68
+ ])
15
69
 
70
+ def forward(self, data: Data) -> Data:
16
71
  assert data.pos is not None
72
+ device = data.pos.device
17
73
 
18
74
  if data.pos.size(0) < 2:
19
- data.edge_index = torch.tensor([], dtype=torch.long,
20
- device=data.pos.device).view(2, 0)
21
- if data.pos.size(0) == 2:
22
- data.edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long,
23
- device=data.pos.device)
75
+ data.edge_index = torch.empty(2, 0, dtype=torch.long,
76
+ device=device)
77
+ elif data.pos.size(0) == 2:
78
+ data.edge_index = torch.tensor([[0, 1], [1, 0]], device=device)
24
79
  elif data.pos.size(0) == 3:
25
- data.face = torch.tensor([[0], [1], [2]], dtype=torch.long,
26
- device=data.pos.device)
27
- if data.pos.size(0) > 3:
28
- pos = data.pos.cpu().numpy()
29
- tri = scipy.spatial.Delaunay(pos, qhull_options='QJ')
30
- face = torch.from_numpy(tri.simplices)
31
-
32
- data.face = face.t().contiguous().to(data.pos.device, torch.long)
80
+ data.face = torch.tensor([[0], [1], [2]], device=device)
81
+ else:
82
+ data = self._transform.forward(data)
33
83
 
34
84
  return data
@@ -8,8 +8,15 @@ from torch_geometric.utils import to_undirected
8
8
 
9
9
  @functional_transform('face_to_edge')
10
10
  class FaceToEdge(BaseTransform):
11
- r"""Converts mesh faces :obj:`[3, num_faces]` to edge indices
12
- :obj:`[2, num_edges]` (functional name: :obj:`face_to_edge`).
11
+ r"""Converts mesh faces of shape :obj:`[3, num_faces]` or
12
+ :obj:`[4, num_faces]` to edge indices of shape :obj:`[2, num_edges]`
13
+ (functional name: :obj:`face_to_edge`).
14
+
15
+ This transform supports both 2D triangular faces, represented by a
16
+ tensor of shape :obj:`[3, num_faces]`, and 3D tetrahedral mesh faces,
17
+ represented by a tensor of shape :obj:`[4, num_faces]`. It will convert
18
+ these faces into edge indices, where each edge is defined by the indices
19
+ of its two endpoints.
13
20
 
14
21
  Args:
15
22
  remove_faces (bool, optional): If set to :obj:`False`, the face tensor
@@ -22,7 +29,29 @@ class FaceToEdge(BaseTransform):
22
29
  if hasattr(data, 'face'):
23
30
  assert data.face is not None
24
31
  face = data.face
25
- edge_index = torch.cat([face[:2], face[1:], face[::2]], dim=1)
32
+
33
+ if face.size(0) not in [3, 4]:
34
+ raise RuntimeError(f"Expected 'face' tensor with shape "
35
+ f"[3, num_faces] or [4, num_faces] "
36
+ f"(got {list(face.size())})")
37
+
38
+ if face.size()[0] == 3:
39
+ edge_index = torch.cat([
40
+ face[:2],
41
+ face[1:],
42
+ face[::2],
43
+ ], dim=1)
44
+ else:
45
+ assert face.size()[0] == 4
46
+ edge_index = torch.cat([
47
+ face[:2],
48
+ face[1:3],
49
+ face[2:4],
50
+ face[::2],
51
+ face[1::2],
52
+ face[::3],
53
+ ], dim=1)
54
+
26
55
  edge_index = to_undirected(edge_index, num_nodes=data.num_nodes)
27
56
 
28
57
  data.edge_index = edge_index
torch_geometric/typing.py CHANGED
@@ -15,6 +15,7 @@ WITH_PT22 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 2
15
15
  WITH_PT23 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 3
16
16
  WITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4
17
17
  WITH_PT25 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 5
18
+ WITH_PT26 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 6
18
19
  WITH_PT111 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 11
19
20
  WITH_PT112 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 12
20
21
  WITH_PT113 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 13
@@ -351,5 +351,5 @@ def group_cat(
351
351
  """
352
352
  assert len(tensors) == len(indices)
353
353
  index, perm = torch.cat(indices).sort(stable=True)
354
- out = torch.cat(tensors, dim=0)[perm]
354
+ out = torch.cat(tensors, dim=dim).index_select(dim, perm)
355
355
  return (out, index) if return_index else out
@@ -346,10 +346,12 @@ def k_hop_subgraph(
346
346
 
347
347
  subsets = [node_idx]
348
348
 
349
+ preserved_edge_mask = torch.zeros_like(edge_mask)
349
350
  for _ in range(num_hops):
350
351
  node_mask.fill_(False)
351
352
  node_mask[subsets[-1]] = True
352
353
  torch.index_select(node_mask, 0, row, out=edge_mask)
354
+ preserved_edge_mask |= edge_mask
353
355
  subsets.append(col[edge_mask])
354
356
 
355
357
  subset, inv = torch.cat(subsets).unique(return_inverse=True)
@@ -360,6 +362,8 @@ def k_hop_subgraph(
360
362
 
361
363
  if not directed:
362
364
  edge_mask = node_mask[row] & node_mask[col]
365
+ else:
366
+ edge_mask = preserved_edge_mask
363
367
 
364
368
  edge_index = edge_index[:, edge_mask]
365
369
 
@@ -132,7 +132,7 @@ def _visualize_graph_via_networkx(
132
132
  xy=pos[src],
133
133
  xytext=pos[dst],
134
134
  arrowprops=dict(
135
- arrowstyle="->",
135
+ arrowstyle="<-",
136
136
  alpha=data['alpha'],
137
137
  shrinkA=sqrt(node_size) / 2.0,
138
138
  shrinkB=sqrt(node_size) / 2.0,