pyg-nightly 2.6.0.dev20240318__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226) hide show
  1. {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +13 -7
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +317 -65
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +3 -5
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +329 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +56 -22
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
@@ -5,7 +5,7 @@ from abc import ABC
5
5
  from collections import defaultdict
6
6
  from dataclasses import dataclass
7
7
  from enum import Enum
8
- from typing import Any, Dict, List, Optional, Union
8
+ from typing import Any, Dict, List, Literal, Optional, Union
9
9
 
10
10
  import torch
11
11
  from torch import Tensor
@@ -425,6 +425,14 @@ class NumNeighbors:
425
425
  else:
426
426
  assert False
427
427
 
428
+ # Confirm that `values` only hold valid edge types:
429
+ if isinstance(self.values, dict):
430
+ edge_types_str = {EdgeTypeStr(key) for key in edge_types}
431
+ invalid_edge_types = set(self.values.keys()) - edge_types_str
432
+ if len(invalid_edge_types) > 0:
433
+ raise ValueError("Not all edge types specified in "
434
+ "'num_neighbors' exist in the graph")
435
+
428
436
  out = {}
429
437
  for edge_type in edge_types:
430
438
  edge_type_str = EdgeTypeStr(edge_type)
@@ -444,7 +452,7 @@ class NumNeighbors:
444
452
  out = copy.copy(self.values)
445
453
 
446
454
  if isinstance(out, dict):
447
- num_hops = set(len(v) for v in out.values())
455
+ num_hops = {len(v) for v in out.values()}
448
456
  if len(num_hops) > 1:
449
457
  raise ValueError(f"Number of hops must be the same across all "
450
458
  f"edge types (got {len(num_hops)} different "
@@ -533,24 +541,31 @@ class NegativeSampling(CastMixin):
533
541
  destination nodes for each positive source node.
534
542
  amount (int or float, optional): The ratio of sampled negative edges to
535
543
  the number of positive edges. (default: :obj:`1`)
536
- weight (torch.Tensor, optional): A node-level vector determining the
537
- sampling of nodes. Does not necessariyl need to sum up to one.
538
- If not given, negative nodes will be sampled uniformly.
544
+ src_weight (torch.Tensor, optional): A node-level vector determining
545
+ the sampling of source nodes. Does not necessarily need to sum up
546
+ to one. If not given, negative nodes will be sampled uniformly.
547
+ (default: :obj:`None`)
548
+ dst_weight (torch.Tensor, optional): A node-level vector determining
549
+ the sampling of destination nodes. Does not necessarily need to sum
550
+ up to one. If not given, negative nodes will be sampled uniformly.
539
551
  (default: :obj:`None`)
540
552
  """
541
553
  mode: NegativeSamplingMode
542
554
  amount: Union[int, float] = 1
543
- weight: Optional[Tensor] = None
555
+ src_weight: Optional[Tensor] = None
556
+ dst_weight: Optional[Tensor] = None
544
557
 
545
558
  def __init__(
546
559
  self,
547
560
  mode: Union[NegativeSamplingMode, str],
548
561
  amount: Union[int, float] = 1,
549
- weight: Optional[Tensor] = None,
562
+ src_weight: Optional[Tensor] = None,
563
+ dst_weight: Optional[Tensor] = None,
550
564
  ):
551
565
  self.mode = NegativeSamplingMode(mode)
552
566
  self.amount = amount
553
- self.weight = weight
567
+ self.src_weight = src_weight
568
+ self.dst_weight = dst_weight
554
569
 
555
570
  if self.amount <= 0:
556
571
  raise ValueError(f"The attribute 'amount' needs to be positive "
@@ -571,22 +586,28 @@ class NegativeSampling(CastMixin):
571
586
  def is_triplet(self) -> bool:
572
587
  return self.mode == NegativeSamplingMode.triplet
573
588
 
574
- def sample(self, num_samples: int,
575
- num_nodes: Optional[int] = None) -> Tensor:
589
+ def sample(
590
+ self,
591
+ num_samples: int,
592
+ endpoint: Literal['src', 'dst'],
593
+ num_nodes: Optional[int] = None,
594
+ ) -> Tensor:
576
595
  r"""Generates :obj:`num_samples` negative samples."""
577
- if self.weight is None:
596
+ weight = self.src_weight if endpoint == 'src' else self.dst_weight
597
+
598
+ if weight is None:
578
599
  if num_nodes is None:
579
600
  raise ValueError(
580
601
  f"Cannot sample negatives in '{self.__class__.__name__}' "
581
602
  f"without passing the 'num_nodes' argument")
582
603
  return torch.randint(num_nodes, (num_samples, ))
583
604
 
584
- if num_nodes is not None and self.weight.numel() != num_nodes:
605
+ if num_nodes is not None and weight.numel() != num_nodes:
585
606
  raise ValueError(
586
607
  f"The 'weight' attribute in '{self.__class__.__name__}' "
587
608
  f"needs to match the number of nodes {num_nodes} "
588
609
  f"(got {self.weight.numel()})")
589
- return torch.multinomial(self.weight, num_samples, replacement=True)
610
+ return torch.multinomial(weight, num_samples, replacement=True)
590
611
 
591
612
 
592
613
  class BaseSampler(ABC):
@@ -2,7 +2,7 @@ import copy
2
2
  import math
3
3
  import sys
4
4
  import warnings
5
- from typing import Callable, Dict, List, Optional, Tuple, Union
5
+ from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
6
6
 
7
7
  import torch
8
8
  from torch import Tensor
@@ -168,7 +168,7 @@ class NeighborSampler(BaseSampler):
168
168
  attrs = [attr for attr in feature_store.get_all_tensor_attrs()]
169
169
 
170
170
  edge_attrs = graph_store.get_all_edge_attrs()
171
- self.edge_types = list(set(attr.edge_type for attr in edge_attrs))
171
+ self.edge_types = list({attr.edge_type for attr in edge_attrs})
172
172
 
173
173
  if weight_attr is not None:
174
174
  raise NotImplementedError(
@@ -593,7 +593,7 @@ def edge_sample(
593
593
  src_node_time = node_time
594
594
 
595
595
  src_neg = neg_sample(src, neg_sampling, num_src_nodes, src_time,
596
- src_node_time)
596
+ src_node_time, endpoint='src')
597
597
  src = torch.cat([src, src_neg], dim=0)
598
598
 
599
599
  if isinstance(node_time, dict):
@@ -602,7 +602,7 @@ def edge_sample(
602
602
  dst_node_time = node_time
603
603
 
604
604
  dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes, dst_time,
605
- dst_node_time)
605
+ dst_node_time, endpoint='dst')
606
606
  dst = torch.cat([dst, dst_neg], dim=0)
607
607
 
608
608
  if edge_label is None:
@@ -623,7 +623,7 @@ def edge_sample(
623
623
  dst_node_time = node_time
624
624
 
625
625
  dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes, dst_time,
626
- dst_node_time)
626
+ dst_node_time, endpoint='dst')
627
627
  dst = torch.cat([dst, dst_neg], dim=0)
628
628
 
629
629
  assert edge_label is None
@@ -631,7 +631,7 @@ def edge_sample(
631
631
  if edge_label_time is not None:
632
632
  dst_time = edge_label_time.repeat(1 + neg_sampling.amount)
633
633
 
634
- # Heterogeneus Neighborhood Sampling ######################################
634
+ # Heterogeneous Neighborhood Sampling #####################################
635
635
 
636
636
  if input_type is not None:
637
637
  seed_time_dict = None
@@ -724,7 +724,7 @@ def edge_sample(
724
724
  src_time,
725
725
  )
726
726
 
727
- # Homogeneus Neighborhood Sampling ########################################
727
+ # Homogeneous Neighborhood Sampling #######################################
728
728
 
729
729
  else:
730
730
 
@@ -781,12 +781,13 @@ def neg_sample(
781
781
  num_nodes: int,
782
782
  seed_time: Optional[Tensor],
783
783
  node_time: Optional[Tensor],
784
+ endpoint: Literal['str', 'dst'],
784
785
  ) -> Tensor:
785
786
  num_neg = math.ceil(seed.numel() * neg_sampling.amount)
786
787
 
787
788
  # TODO: Do not sample false negatives.
788
789
  if node_time is None:
789
- return neg_sampling.sample(num_neg, num_nodes)
790
+ return neg_sampling.sample(num_neg, endpoint, num_nodes)
790
791
 
791
792
  # If we are in a temporal-sampling scenario, we need to respect the
792
793
  # timestamp of the given nodes we can use as negative examples.
@@ -800,7 +801,7 @@ def neg_sample(
800
801
  num_samples = math.ceil(neg_sampling.amount)
801
802
  seed_time = seed_time.view(1, -1).expand(num_samples, -1)
802
803
 
803
- out = neg_sampling.sample(num_samples * seed.numel(), num_nodes)
804
+ out = neg_sampling.sample(num_samples * seed.numel(), endpoint, num_nodes)
804
805
  out = out.view(num_samples, seed.numel())
805
806
  mask = node_time[out] > seed_time # holds all invalid samples.
806
807
  neg_sampling_complete = False
@@ -811,7 +812,7 @@ def neg_sample(
811
812
  break
812
813
 
813
814
  # Greedily search for alternative negatives.
814
- out[mask] = tmp = neg_sampling.sample(num_invalid, num_nodes)
815
+ out[mask] = tmp = neg_sampling.sample(num_invalid, endpoint, num_nodes)
815
816
  mask[mask.clone()] = node_time[tmp] >= seed_time[mask]
816
817
 
817
818
  if not neg_sampling_complete: # pragma: no cover
@@ -5,9 +5,9 @@ from torch import Tensor
5
5
 
6
6
  from torch_geometric.data import Data, HeteroData
7
7
  from torch_geometric.data.storage import EdgeStorage
8
+ from torch_geometric.index import index2ptr
8
9
  from torch_geometric.typing import EdgeType, NodeType, OptTensor
9
10
  from torch_geometric.utils import coalesce, index_sort, lexsort
10
- from torch_geometric.utils.sparse import index2ptr
11
11
 
12
12
  # Edge Layout Conversion ######################################################
13
13
 
@@ -28,6 +28,7 @@ def module_from_template(
28
28
  delete=False,
29
29
  ) as tmp:
30
30
  tmp.write(module_repr)
31
+ tmp.flush()
31
32
 
32
33
  spec = importlib.util.spec_from_file_location(module_name, tmp.name)
33
34
  assert spec is not None
@@ -10,7 +10,8 @@ from .decorators import (
10
10
  onlyDistributedTest,
11
11
  onlyLinux,
12
12
  noWindows,
13
- onlyPython,
13
+ noMac,
14
+ minPython,
14
15
  onlyCUDA,
15
16
  onlyXPU,
16
17
  onlyOnline,
@@ -18,6 +19,7 @@ from .decorators import (
18
19
  onlyNeighborSampler,
19
20
  has_package,
20
21
  withPackage,
22
+ withDevice,
21
23
  withCUDA,
22
24
  withMETIS,
23
25
  disableExtensions,
@@ -39,7 +41,8 @@ __all__ = [
39
41
  'onlyDistributedTest',
40
42
  'onlyLinux',
41
43
  'noWindows',
42
- 'onlyPython',
44
+ 'noMac',
45
+ 'minPython',
43
46
  'onlyCUDA',
44
47
  'onlyXPU',
45
48
  'onlyOnline',
@@ -47,6 +50,7 @@ __all__ = [
47
50
  'onlyNeighborSampler',
48
51
  'has_package',
49
52
  'withPackage',
53
+ 'withDevice',
50
54
  'withCUDA',
51
55
  'withMETIS',
52
56
  'disableExtensions',
@@ -7,7 +7,9 @@ from typing import Callable
7
7
 
8
8
  import torch
9
9
  from packaging.requirements import Requirement
10
+ from packaging.version import Version
10
11
 
12
+ import torch_geometric
11
13
  from torch_geometric.typing import WITH_METIS, WITH_PYG_LIB, WITH_TORCH_SPARSE
12
14
  from torch_geometric.visualization.graph import has_graphviz
13
15
 
@@ -67,15 +69,34 @@ def noWindows(func: Callable) -> Callable:
67
69
  )(func)
68
70
 
69
71
 
70
- def onlyPython(*args: str) -> Callable:
72
+ def noMac(func: Callable) -> Callable:
73
+ r"""A decorator to specify that this function should not execute on
74
+ macOS systems.
75
+ """
76
+ import pytest
77
+ return pytest.mark.skipif(
78
+ sys.platform == 'darwin',
79
+ reason="macOS system",
80
+ )(func)
81
+
82
+
83
+ def minPython(version: str) -> Callable:
71
84
  r"""A decorator to run tests on specific :python:`Python` versions only."""
72
85
  def decorator(func: Callable) -> Callable:
73
86
  import pytest
74
87
 
75
- python_version = f'{sys.version_info.major}.{sys.version_info.minor}'
88
+ major, minor = version.split('.')
89
+
90
+ skip = False
91
+ if sys.version_info.major < int(major):
92
+ skip = True
93
+ if (sys.version_info.major == int(major)
94
+ and sys.version_info.minor < int(minor)):
95
+ skip = True
96
+
76
97
  return pytest.mark.skipif(
77
- python_version not in args,
78
- reason=f"Python {python_version} not supported",
98
+ skip,
99
+ reason=f"Python {version} required",
79
100
  )(func)
80
101
 
81
102
  return decorator
@@ -93,13 +114,8 @@ def onlyCUDA(func: Callable) -> Callable:
93
114
  def onlyXPU(func: Callable) -> Callable:
94
115
  r"""A decorator to skip tests if XPU is not found."""
95
116
  import pytest
96
- try:
97
- import intel_extension_for_pytorch as ipex
98
- xpu_available = ipex.xpu.is_available()
99
- except ImportError:
100
- xpu_available = False
101
117
  return pytest.mark.skipif(
102
- not xpu_available,
118
+ not torch_geometric.is_xpu_available(),
103
119
  reason="XPU not available",
104
120
  )(func)
105
121
 
@@ -157,24 +173,23 @@ def has_package(package: str) -> bool:
157
173
  req = Requirement(package)
158
174
  if find_spec(req.name) is None:
159
175
  return False
160
- module = import_module(req.name)
161
- if not hasattr(module, '__version__'):
162
- return True
163
176
 
164
- version = module.__version__
165
- # `req.specifier` does not support `.dev` suffixes, e.g., for
166
- # `pyg_lib==0.1.0.dev*`, so we manually drop them:
167
- if '.dev' in version:
168
- version = '.'.join(version.split('.dev')[:-1])
177
+ try:
178
+ module = import_module(req.name)
179
+ if not hasattr(module, '__version__'):
180
+ return True
169
181
 
170
- return version in req.specifier
182
+ version = Version(module.__version__).base_version
183
+ return version in req.specifier
184
+ except Exception:
185
+ return False
171
186
 
172
187
 
173
188
  def withPackage(*args: str) -> Callable:
174
189
  r"""A decorator to skip tests if certain packages are not installed.
175
190
  Also supports version specification.
176
191
  """
177
- na_packages = set(package for package in args if not has_package(package))
192
+ na_packages = {package for package in args if not has_package(package)}
178
193
 
179
194
  if len(na_packages) == 1:
180
195
  reason = f"Package {list(na_packages)[0]} not found"
@@ -196,6 +211,24 @@ def withCUDA(func: Callable) -> Callable:
196
211
  if torch.cuda.is_available():
197
212
  devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0'))
198
213
 
214
+ return pytest.mark.parametrize('device', devices)(func)
215
+
216
+
217
+ def withDevice(func: Callable) -> Callable:
218
+ r"""A decorator to test on all available tensor processing devices."""
219
+ import pytest
220
+
221
+ devices = [pytest.param(torch.device('cpu'), id='cpu')]
222
+
223
+ if torch.cuda.is_available():
224
+ devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0'))
225
+
226
+ if torch_geometric.is_mps_available():
227
+ devices.append(pytest.param(torch.device('mps:0'), id='mps'))
228
+
229
+ if torch_geometric.is_xpu_available():
230
+ devices.append(pytest.param(torch.device('xpu:0'), id='xpu'))
231
+
199
232
  # Additional devices can be registered through environment variables:
200
233
  device = os.getenv('TORCH_DEVICE')
201
234
  if device:
@@ -218,10 +251,11 @@ def withMETIS(func: Callable) -> Callable:
218
251
 
219
252
  if with_metis:
220
253
  try: # Test that METIS can succesfully execute:
221
- import pyg_lib
254
+ # TODO Using `pyg-lib` metis partitioning leads to some weird bugs
255
+ # in the # CI. As such, we require `torch-sparse` for now.
222
256
  rowptr = torch.tensor([0, 2, 4, 6])
223
257
  col = torch.tensor([1, 2, 0, 2, 1, 0])
224
- pyg_lib.partition.metis(rowptr, col, num_partitions=2)
258
+ torch.ops.torch_sparse.partition(rowptr, col, None, 2, True)
225
259
  except Exception:
226
260
  with_metis = False
227
261
 
@@ -36,7 +36,7 @@ class MyFeatureStore(FeatureStore):
36
36
  index, tensor = self.store.get(self.key(attr), (None, None))
37
37
 
38
38
  if tensor is None:
39
- return None
39
+ raise KeyError(f"Could not find tensor for '{attr}'")
40
40
 
41
41
  assert isinstance(tensor, Tensor)
42
42
 
@@ -20,6 +20,7 @@ from .target_indegree import TargetIndegree
20
20
  from .local_degree_profile import LocalDegreeProfile
21
21
  from .add_self_loops import AddSelfLoops
22
22
  from .add_remaining_self_loops import AddRemainingSelfLoops
23
+ from .remove_self_loops import RemoveSelfLoops
23
24
  from .remove_isolated_nodes import RemoveIsolatedNodes
24
25
  from .remove_duplicated_edges import RemoveDuplicatedEdges
25
26
  from .knn_graph import KNNGraph
@@ -87,6 +88,7 @@ graph_transforms = [
87
88
  'LocalDegreeProfile',
88
89
  'AddSelfLoops',
89
90
  'AddRemainingSelfLoops',
91
+ 'RemoveSelfLoops',
90
92
  'RemoveIsolatedNodes',
91
93
  'RemoveDuplicatedEdges',
92
94
  'KNNGraph',
@@ -37,7 +37,7 @@ class AddMetaPaths(BaseTransform):
37
37
  :class:`~torch_geometric.data.HeteroData` object as edge type
38
38
  :obj:`(src_node_type, "metapath_*", dst_node_type)`, where
39
39
  :obj:`src_node_type` and :obj:`dst_node_type` denote :math:`\mathcal{V}_1`
40
- and :math:`\mathcal{V}_{\ell}`, repectively.
40
+ and :math:`\mathcal{V}_{\ell}`, respectively.
41
41
 
42
42
  In addition, a :obj:`metapath_dict` object is added to the
43
43
  :class:`~torch_geometric.data.HeteroData` object which maps the
@@ -108,12 +108,12 @@ class AddMetaPaths(BaseTransform):
108
108
  **kwargs: bool,
109
109
  ) -> None:
110
110
  if 'drop_orig_edges' in kwargs:
111
- warnings.warn("'drop_orig_edges' is dprecated. Use "
111
+ warnings.warn("'drop_orig_edges' is deprecated. Use "
112
112
  "'drop_orig_edge_types' instead")
113
113
  drop_orig_edge_types = kwargs['drop_orig_edges']
114
114
 
115
115
  if 'drop_unconnected_nodes' in kwargs:
116
- warnings.warn("'drop_unconnected_nodes' is dprecated. Use "
116
+ warnings.warn("'drop_unconnected_nodes' is deprecated. Use "
117
117
  "'drop_unconnected_node_types' instead")
118
118
  drop_unconnected_node_types = kwargs['drop_unconnected_nodes']
119
119
 
@@ -158,7 +158,7 @@ class AddMetaPaths(BaseTransform):
158
158
  edge_index, edge_weight)
159
159
 
160
160
  new_edge_type = (metapath[0][0], f'metapath_{j}', metapath[-1][-1])
161
- data[new_edge_type].edge_index = edge_index
161
+ data[new_edge_type].edge_index = edge_index.as_tensor()
162
162
  if self.weighted:
163
163
  data[new_edge_type].edge_weight = edge_weight
164
164
  data.metapath_dict[new_edge_type] = metapath
@@ -231,7 +231,7 @@ class AddRandomMetaPaths(BaseTransform):
231
231
  will drop node types not connected by any edge type.
232
232
  (default: :obj:`False`)
233
233
  walks_per_node (int, List[int], optional): The number of random walks
234
- for each starting node in a metapth. (default: :obj:`1`)
234
+ for each starting node in a metapath. (default: :obj:`1`)
235
235
  sample_ratio (float, optional): The ratio of source nodes to start
236
236
  random walks from. (default: :obj:`1.0`)
237
237
  """
@@ -92,7 +92,7 @@ class AddLaplacianEigenvectorPE(BaseTransform):
92
92
  from numpy.linalg import eig, eigh
93
93
  eig_fn = eig if not self.is_undirected else eigh
94
94
 
95
- eig_vals, eig_vecs = eig_fn(L.todense()) # type: ignore
95
+ eig_vals, eig_vecs = eig_fn(L.todense())
96
96
  else:
97
97
  from scipy.sparse.linalg import eigs, eigsh
98
98
  eig_fn = eigs if not self.is_undirected else eigsh
@@ -1,4 +1,5 @@
1
- import scipy.spatial
1
+ from typing import List
2
+
2
3
  import torch
3
4
 
4
5
  from torch_geometric.data import Data
@@ -6,28 +7,78 @@ from torch_geometric.data.datapipes import functional_transform
6
7
  from torch_geometric.transforms import BaseTransform
7
8
 
8
9
 
10
+ class _QhullTransform(BaseTransform):
11
+ r"""Q-hull implementation of delaunay triangulation."""
12
+ def forward(self, data: Data) -> Data:
13
+ assert data.pos is not None
14
+ import scipy.spatial
15
+
16
+ pos = data.pos.cpu().numpy()
17
+ tri = scipy.spatial.Delaunay(pos, qhull_options='QJ')
18
+ face = torch.from_numpy(tri.simplices)
19
+
20
+ data.face = face.t().contiguous().to(data.pos.device, torch.long)
21
+ return data
22
+
23
+
24
+ class _ShullTransform(BaseTransform):
25
+ r"""Sweep-hull implementation of delaunay triangulation."""
26
+ def forward(self, data: Data) -> Data:
27
+ assert data.pos is not None
28
+ from torch_delaunay.functional import shull2d
29
+
30
+ face = shull2d(data.pos.cpu())
31
+ data.face = face.t().contiguous().to(data.pos.device)
32
+ return data
33
+
34
+
35
+ class _SequentialTransform(BaseTransform):
36
+ r"""Runs the first successful transformation.
37
+
38
+ All intermediate exceptions are suppressed except the last.
39
+ """
40
+ def __init__(self, transforms: List[BaseTransform]) -> None:
41
+ assert len(transforms) > 0
42
+ self.transforms = transforms
43
+
44
+ def forward(self, data: Data) -> Data:
45
+ for i, transform in enumerate(self.transforms):
46
+ try:
47
+ return transform.forward(data)
48
+ except ImportError as e:
49
+ if i == len(self.transforms) - 1:
50
+ raise e
51
+ return data
52
+
53
+
9
54
  @functional_transform('delaunay')
10
55
  class Delaunay(BaseTransform):
11
56
  r"""Computes the delaunay triangulation of a set of points
12
57
  (functional name: :obj:`delaunay`).
58
+
59
+ .. hint::
60
+ Consider installing the
61
+ `torch_delaunay <https://github.com/ybubnov/torch_delaunay>`_ package
62
+ to speed up computation.
13
63
  """
64
+ def __init__(self) -> None:
65
+ self._transform = _SequentialTransform([
66
+ _ShullTransform(),
67
+ _QhullTransform(),
68
+ ])
69
+
14
70
  def forward(self, data: Data) -> Data:
15
71
  assert data.pos is not None
72
+ device = data.pos.device
16
73
 
17
74
  if data.pos.size(0) < 2:
18
- data.edge_index = torch.tensor([], dtype=torch.long,
19
- device=data.pos.device).view(2, 0)
20
- if data.pos.size(0) == 2:
21
- data.edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long,
22
- device=data.pos.device)
75
+ data.edge_index = torch.empty(2, 0, dtype=torch.long,
76
+ device=device)
77
+ elif data.pos.size(0) == 2:
78
+ data.edge_index = torch.tensor([[0, 1], [1, 0]], device=device)
23
79
  elif data.pos.size(0) == 3:
24
- data.face = torch.tensor([[0], [1], [2]], dtype=torch.long,
25
- device=data.pos.device)
26
- if data.pos.size(0) > 3:
27
- pos = data.pos.cpu().numpy()
28
- tri = scipy.spatial.Delaunay(pos, qhull_options='QJ')
29
- face = torch.from_numpy(tri.simplices)
30
-
31
- data.face = face.t().contiguous().to(data.pos.device, torch.long)
80
+ data.face = torch.tensor([[0], [1], [2]], device=device)
81
+ else:
82
+ data = self._transform.forward(data)
32
83
 
33
84
  return data
@@ -8,8 +8,15 @@ from torch_geometric.utils import to_undirected
8
8
 
9
9
  @functional_transform('face_to_edge')
10
10
  class FaceToEdge(BaseTransform):
11
- r"""Converts mesh faces :obj:`[3, num_faces]` to edge indices
12
- :obj:`[2, num_edges]` (functional name: :obj:`face_to_edge`).
11
+ r"""Converts mesh faces of shape :obj:`[3, num_faces]` or
12
+ :obj:`[4, num_faces]` to edge indices of shape :obj:`[2, num_edges]`
13
+ (functional name: :obj:`face_to_edge`).
14
+
15
+ This transform supports both 2D triangular faces, represented by a
16
+ tensor of shape :obj:`[3, num_faces]`, and 3D tetrahedral mesh faces,
17
+ represented by a tensor of shape :obj:`[4, num_faces]`. It will convert
18
+ these faces into edge indices, where each edge is defined by the indices
19
+ of its two endpoints.
13
20
 
14
21
  Args:
15
22
  remove_faces (bool, optional): If set to :obj:`False`, the face tensor
@@ -22,7 +29,29 @@ class FaceToEdge(BaseTransform):
22
29
  if hasattr(data, 'face'):
23
30
  assert data.face is not None
24
31
  face = data.face
25
- edge_index = torch.cat([face[:2], face[1:], face[::2]], dim=1)
32
+
33
+ if face.size(0) not in [3, 4]:
34
+ raise RuntimeError(f"Expected 'face' tensor with shape "
35
+ f"[3, num_faces] or [4, num_faces] "
36
+ f"(got {list(face.size())})")
37
+
38
+ if face.size()[0] == 3:
39
+ edge_index = torch.cat([
40
+ face[:2],
41
+ face[1:],
42
+ face[::2],
43
+ ], dim=1)
44
+ else:
45
+ assert face.size()[0] == 4
46
+ edge_index = torch.cat([
47
+ face[:2],
48
+ face[1:3],
49
+ face[2:4],
50
+ face[::2],
51
+ face[1::2],
52
+ face[::3],
53
+ ], dim=1)
54
+
26
55
  edge_index = to_undirected(edge_index, num_nodes=data.num_nodes)
27
56
 
28
57
  data.edge_index = edge_index
@@ -2,7 +2,6 @@ from typing import Any, Dict, Tuple
2
2
 
3
3
  import numpy as np
4
4
  import torch
5
- from scipy.linalg import expm
6
5
  from torch import Tensor
7
6
 
8
7
  from torch_geometric.data import Data
@@ -22,7 +21,7 @@ from torch_geometric.utils import (
22
21
  @functional_transform('gdc')
23
22
  class GDC(BaseTransform):
24
23
  r"""Processes the graph via Graph Diffusion Convolution (GDC) from the
25
- `"Diffusion Improves Graph Learning" <https://www.kdd.in.tum.de/gdc>`_
24
+ `"Diffusion Improves Graph Learning" <https://arxiv.org/abs/1911.05485>`_
26
25
  paper (functional name: :obj:`gdc`).
27
26
 
28
27
  .. note::
@@ -338,10 +337,10 @@ class GDC(BaseTransform):
338
337
 
339
338
  elif method == 'heat':
340
339
  raise NotImplementedError(
341
- ('Currently no fast heat kernel is implemented. You are '
342
- 'welcome to create one yourself, e.g., based on '
343
- '"Kloster and Gleich: Heat kernel based community detection '
344
- '(KDD 2014)."'))
340
+ 'Currently no fast heat kernel is implemented. You are '
341
+ 'welcome to create one yourself, e.g., based on '
342
+ '"Kloster and Gleich: Heat kernel based community detection '
343
+ '(KDD 2014)."')
345
344
  else:
346
345
  raise ValueError(f"Approximate GDC diffusion '{method}' unknown")
347
346
 
@@ -473,6 +472,8 @@ class GDC(BaseTransform):
473
472
 
474
473
  :rtype: (:class:`Tensor`)
475
474
  """
475
+ from scipy.linalg import expm
476
+
476
477
  if symmetric:
477
478
  e, V = torch.linalg.eigh(matrix, UPLO='U')
478
479
  diff_mat = V @ torch.diag(e.exp()) @ V.t()