pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226) hide show
  1. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +8 -3
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +159 -34
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +2 -4
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +322 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +53 -20
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
@@ -142,7 +142,7 @@ class WILLOWObjectClass(InMemoryDataset):
142
142
  pos[:, 0] = pos[:, 0] * 256.0 / (img.size[0])
143
143
  pos[:, 1] = pos[:, 1] * 256.0 / (img.size[1])
144
144
 
145
- img = img.resize((256, 256), resample=Image.BICUBIC)
145
+ img = img.resize((256, 256), resample=Image.Resampling.BICUBIC)
146
146
  img = transform(img)
147
147
 
148
148
  data = Data(img=img, pos=pos, name=name)
@@ -67,7 +67,7 @@ class WordNet18(InMemoryDataset):
67
67
  def process(self) -> None:
68
68
  srcs, dsts, edge_types = [], [], []
69
69
  for path in self.raw_paths:
70
- with open(path, 'r') as f:
70
+ with open(path) as f:
71
71
  edges = [int(x) for x in f.read().split()[1:]]
72
72
  edge = torch.tensor(edges, dtype=torch.long)
73
73
  srcs.append(edge[::3])
@@ -173,7 +173,7 @@ class WordNet18RR(InMemoryDataset):
173
173
 
174
174
  srcs, dsts, edge_types = [], [], []
175
175
  for path in self.raw_paths:
176
- with open(path, 'r') as f:
176
+ with open(path) as f:
177
177
  edges = f.read().split()
178
178
 
179
179
  _src = edges[::3]
@@ -3,7 +3,6 @@ import os.path as osp
3
3
  from typing import Callable, List, Optional
4
4
 
5
5
  import numpy as np
6
- import scipy.sparse as sp
7
6
  import torch
8
7
 
9
8
  from torch_geometric.data import Data, InMemoryDataset, download_google_url
@@ -73,6 +72,8 @@ class Yelp(InMemoryDataset):
73
72
  download_google_url(self.role_id, self.raw_dir, 'role.json')
74
73
 
75
74
  def process(self) -> None:
75
+ import scipy.sparse as sp
76
+
76
77
  f = np.load(osp.join(self.raw_dir, 'adj_full.npz'))
77
78
  adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape'])
78
79
  adj = adj.tocoo()
@@ -139,7 +139,7 @@ class ZINC(InMemoryDataset):
139
139
  indices = list(range(len(mols)))
140
140
 
141
141
  if self.subset:
142
- with open(osp.join(self.raw_dir, f'{split}.index'), 'r') as f:
142
+ with open(osp.join(self.raw_dir, f'{split}.index')) as f:
143
143
  indices = [int(x) for x in f.read()[:-1].split(',')]
144
144
 
145
145
  pbar = tqdm(total=len(indices))
@@ -0,0 +1,42 @@
1
+ from typing import Any
2
+
3
+ import torch
4
+
5
+
6
+ def is_mps_available() -> bool:
7
+ r"""Returns a bool indicating if MPS is currently available."""
8
+ if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
9
+ try: # Github CI may not have access to MPS hardware. Confirm:
10
+ torch.empty(1, device='mps')
11
+ return True
12
+ except Exception:
13
+ return False
14
+ return False
15
+
16
+
17
+ def is_xpu_available() -> bool:
18
+ r"""Returns a bool indicating if XPU is currently available."""
19
+ if hasattr(torch, 'xpu') and torch.xpu.is_available():
20
+ return True
21
+ try:
22
+ import intel_extension_for_pytorch as ipex
23
+ return ipex.xpu.is_available()
24
+ except ImportError:
25
+ return False
26
+
27
+
28
+ def device(device: Any) -> torch.device:
29
+ r"""Returns a :class:`torch.device`.
30
+
31
+ If :obj:`"auto"` is specified, returns the optimal device depending on
32
+ available hardware.
33
+ """
34
+ if device != 'auto':
35
+ return torch.device(device)
36
+ if torch.cuda.is_available():
37
+ return torch.device('cuda')
38
+ if is_mps_available():
39
+ return torch.device('mps')
40
+ if is_xpu_available():
41
+ return torch.device('xpu')
42
+ return torch.device('cpu')
@@ -15,6 +15,7 @@ from torch_geometric.distributed.rpc import (
15
15
  rpc_async,
16
16
  rpc_register,
17
17
  )
18
+ from torch_geometric.io import fs
18
19
  from torch_geometric.typing import EdgeType, NodeOrEdgeType, NodeType
19
20
 
20
21
 
@@ -415,11 +416,11 @@ class LocalFeatureStore(FeatureStore):
415
416
 
416
417
  node_feats: Optional[Dict[str, Any]] = None
417
418
  if osp.exists(osp.join(part_dir, 'node_feats.pt')):
418
- node_feats = torch.load(osp.join(part_dir, 'node_feats.pt'))
419
+ node_feats = fs.torch_load(osp.join(part_dir, 'node_feats.pt'))
419
420
 
420
421
  edge_feats: Optional[Dict[str, Any]] = None
421
422
  if osp.exists(osp.join(part_dir, 'edge_feats.pt')):
422
- edge_feats = torch.load(osp.join(part_dir, 'edge_feats.pt'))
423
+ edge_feats = fs.torch_load(osp.join(part_dir, 'edge_feats.pt'))
423
424
 
424
425
  if not meta['is_hetero'] and node_feats is not None:
425
426
  feat_store.put_global_id(node_feats['global_id'], group_name=None)
@@ -6,6 +6,7 @@ from torch import Tensor
6
6
 
7
7
  from torch_geometric.data import EdgeAttr, GraphStore
8
8
  from torch_geometric.distributed.partition import load_partition_info
9
+ from torch_geometric.io import fs
9
10
  from torch_geometric.typing import EdgeTensorType, EdgeType, NodeType
10
11
  from torch_geometric.utils import sort_edge_index
11
12
 
@@ -185,7 +186,7 @@ class LocalGraphStore(GraphStore):
185
186
  graph_store.edge_pb = edge_pb
186
187
  graph_store.meta = meta
187
188
 
188
- graph_data = torch.load(osp.join(part_dir, 'graph.pt'))
189
+ graph_data = fs.torch_load(osp.join(part_dir, 'graph.pt'))
189
190
  graph_store.is_sorted = meta['is_sorted']
190
191
 
191
192
  if not meta['is_hetero']:
@@ -3,15 +3,16 @@ import logging
3
3
  import os
4
4
  import os.path as osp
5
5
  from collections import defaultdict
6
- from typing import List, Optional, Union
6
+ from typing import Dict, List, Optional, Tuple, Union
7
7
 
8
8
  import torch
9
9
 
10
10
  import torch_geometric.distributed as pyg_dist
11
11
  from torch_geometric.data import Data, HeteroData
12
+ from torch_geometric.io import fs
12
13
  from torch_geometric.loader.cluster import ClusterData
13
14
  from torch_geometric.sampler.utils import sort_csc
14
- from torch_geometric.typing import Dict, EdgeType, EdgeTypeStr, NodeType, Tuple
15
+ from torch_geometric.typing import EdgeType, EdgeTypeStr, NodeType
15
16
 
16
17
 
17
18
  class Partitioner:
@@ -23,7 +24,7 @@ class Partitioner:
23
24
 
24
25
  **Homogeneous graphs:**
25
26
 
26
- .. code-block::
27
+ .. code-block:: none
27
28
 
28
29
  root/
29
30
  |-- META.json
@@ -40,7 +41,7 @@ class Partitioner:
40
41
 
41
42
  **Heterogeneous graphs:**
42
43
 
43
- .. code-block::
44
+ .. code-block:: none
44
45
 
45
46
  root/
46
47
  |-- META.json
@@ -380,21 +381,21 @@ def load_partition_info(
380
381
  assert osp.exists(partition_dir)
381
382
 
382
383
  if meta['is_hetero'] is False:
383
- node_pb = torch.load(osp.join(root_dir, 'node_map.pt'))
384
- edge_pb = torch.load(osp.join(root_dir, 'edge_map.pt'))
384
+ node_pb = fs.torch_load(osp.join(root_dir, 'node_map.pt'))
385
+ edge_pb = fs.torch_load(osp.join(root_dir, 'edge_map.pt'))
385
386
 
386
387
  return (meta, num_partitions, partition_idx, node_pb, edge_pb)
387
388
  else:
388
389
  node_pb_dict = {}
389
390
  node_pb_dir = osp.join(root_dir, 'node_map')
390
391
  for ntype in meta['node_types']:
391
- node_pb_dict[ntype] = torch.load(
392
+ node_pb_dict[ntype] = fs.torch_load(
392
393
  osp.join(node_pb_dir, f'{pyg_dist.utils.as_str(ntype)}.pt'))
393
394
 
394
395
  edge_pb_dict = {}
395
396
  edge_pb_dir = osp.join(root_dir, 'edge_map')
396
397
  for etype in meta['edge_types']:
397
- edge_pb_dict[tuple(etype)] = torch.load(
398
+ edge_pb_dict[tuple(etype)] = fs.torch_load(
398
399
  osp.join(edge_pb_dir, f'{pyg_dist.utils.as_str(etype)}.pt'))
399
400
 
400
401
  return (meta, num_partitions, partition_idx, node_pb_dict,