pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Files changed (226)
  1. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +8 -3
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +159 -34
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +2 -4
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +322 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +53 -20
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
@@ -305,7 +305,7 @@ class Inspector:
         according to its function signature from a data blob.

         Args:
-            func (callabel or str): The function.
+            func (callable or str): The function.
             kwargs (dict[str, Any]): The data blob which may serve as inputs.
         """
         out_dict: Dict[str, Any] = {}
@@ -346,7 +346,7 @@ class Inspector:
            type annotations are not found.

        Args:
-            func (callabel or str): The function.
+            func (callable or str): The function.
            exclude (list[int or str]): A list of parameters to exclude, either
                given by their name or index. (default: :obj:`None`)
        """
@@ -401,7 +401,8 @@ class Inspector:
            match = find_parenthesis_content(source, f'self.{func_name}')
            if match is not None:
                for i, kwarg in enumerate(split(match, sep=',')):
-                    if exclude is not None and i in exclude:
+                    if ('=' not in kwarg and exclude is not None
+                            and i in exclude):
                        continue

                    name_and_content = re.split(r'\s*=\s*', kwarg)
@@ -447,6 +448,10 @@ def type_repr(obj: Any, _globals: Dict[str, Any]) -> str:
        return '...'

    if obj.__module__ == 'typing':  # Special logic for `typing.*` types:
+
+        if not hasattr(obj, '_name'):
+            return repr(obj)
+
        name = obj._name
        if name is None:  # In some cases, `_name` is not populated.
            name = str(obj.__origin__).split('.')[-1]
torch_geometric/io/fs.py CHANGED
@@ -1,6 +1,9 @@
 import io
 import os.path as osp
+import pickle
+import re
 import sys
+import warnings
 from typing import Any, Dict, List, Literal, Optional, Union, overload
 from uuid import uuid4

@@ -186,11 +189,11 @@ def rm(path: str, recursive: bool = True) -> None:
     get_fs(path).rm(path, recursive)


-def mv(path1: str, path2: str, recursive: bool = True) -> None:
+def mv(path1: str, path2: str) -> None:
     fs1 = get_fs(path1)
     fs2 = get_fs(path2)
     assert fs1.protocol == fs2.protocol
-    fs1.mv(path1, path2, recursive)
+    fs1.mv(path1, path2)


 def glob(path: str) -> List[str]:
@@ -211,5 +214,28 @@ def torch_save(data: Any, path: str) -> None:


 def torch_load(path: str, map_location: Any = None) -> Any:
+    if torch_geometric.typing.WITH_PT24:
+        try:
+            with fsspec.open(path, 'rb') as f:
+                return torch.load(f, map_location, weights_only=True)
+        except pickle.UnpicklingError as e:
+            error_msg = str(e)
+            if "add_safe_globals" in error_msg:
+                warn_msg = ("Weights only load failed. Please file an issue "
+                            "to make `torch.load(weights_only=True)` "
+                            "compatible in your case.")
+                match = re.search(r'add_safe_globals\(.*?\)', error_msg)
+                if match is not None:
+                    warnings.warn(f"{warn_msg} Please use "
+                                  f"`torch.serialization.{match.group()}` to "
+                                  f"allowlist this global.")
+                else:
+                    warnings.warn(warn_msg)
+
+                with fsspec.open(path, 'rb') as f:
+                    return torch.load(f, map_location, weights_only=False)
+            else:
+                raise e
+
     with fsspec.open(path, 'rb') as f:
         return torch.load(f, map_location)
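
Note: a minimal sketch of how the new `weights_only` handling in `fs.torch_load` behaves on PyTorch >= 2.4. The `MyMetaData` class and the temporary path are invented for illustration; `torch.serialization.add_safe_globals` is the upstream PyTorch API named in the warning above.

```python
from dataclasses import dataclass

import torch

from torch_geometric.io import fs


@dataclass
class MyMetaData:  # hypothetical non-tensor object stored in a checkpoint
    name: str


torch.save({'meta': MyMetaData('toy'), 'x': torch.randn(4)}, '/tmp/blob.pt')

# Without allowlisting, `fs.torch_load` warns and falls back to a full
# (weights_only=False) load; allowlisting the class keeps the safe path working:
torch.serialization.add_safe_globals([MyMetaData])
out = fs.torch_load('/tmp/blob.pt')
print(out['meta'], out['x'].shape)
```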
torch_geometric/io/npz.py CHANGED
@@ -1,7 +1,6 @@
 from typing import Any, Dict

 import numpy as np
-import scipy.sparse as sp
 import torch

 from torch_geometric.data import Data
@@ -15,6 +14,8 @@ def read_npz(path: str, to_undirected: bool = True) -> Data:


 def parse_npz(f: Dict[str, Any], to_undirected: bool = True) -> Data:
+    import scipy.sparse as sp
+
     x = sp.csr_matrix((f['attr_data'], f['attr_indices'], f['attr_indptr']),
                       f['attr_shape']).todense()
     x = torch.from_numpy(x).to(torch.float)
torch_geometric/io/off.py CHANGED
@@ -16,7 +16,7 @@ def parse_off(src: List[str]) -> Data:
     else:
         src[0] = src[0][3:]

-    num_nodes, num_faces = [int(item) for item in src[0].split()[:2]]
+    num_nodes, num_faces = (int(item) for item in src[0].split()[:2])

     pos = parse_txt_array(src[1:1 + num_nodes])

@@ -52,7 +52,7 @@ def read_off(path: str) -> Data:
     Args:
         path (str): The path to the file.
     """
-    with open(path, 'r') as f:
+    with open(path) as f:
         src = f.read().split('\n')[:-1]
     return parse_off(src)

torch_geometric/io/sdf.py CHANGED
@@ -9,7 +9,7 @@ elems = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}

 def parse_sdf(src: str) -> Data:
     lines = src.split('\n')[3:]
-    num_atoms, num_bonds = [int(item) for item in lines[0].split()[:2]]
+    num_atoms, num_bonds = (int(item) for item in lines[0].split()[:2])

     atom_block = lines[1:num_atoms + 1]
     pos = parse_txt_array(atom_block, end=3)
@@ -28,5 +28,5 @@ def parse_sdf(src: str) -> Data:


 def read_sdf(path: str) -> Data:
-    with open(path, 'r') as f:
+    with open(path) as f:
         return parse_sdf(f.read())
torch_geometric/io/tu.py CHANGED
@@ -1,7 +1,6 @@
 import os.path as osp
 from typing import Dict, List, Optional, Tuple

-import numpy as np
 import torch
 from torch import Tensor

@@ -37,7 +36,7 @@ def read_tu_data(
         if node_label.dim() == 1:
             node_label = node_label.unsqueeze(-1)
         node_label = node_label - node_label.min(dim=0)[0]
-        node_labels = node_label.unbind(dim=-1)
+        node_labels = list(node_label.unbind(dim=-1))
         node_labels = [one_hot(x) for x in node_labels]
         if len(node_labels) == 1:
             node_label = node_labels[0]
@@ -56,7 +55,7 @@ def read_tu_data(
         if edge_label.dim() == 1:
             edge_label = edge_label.unsqueeze(-1)
         edge_label = edge_label - edge_label.min(dim=0)[0]
-        edge_labels = edge_label.unbind(dim=-1)
+        edge_labels = list(edge_label.unbind(dim=-1))
         edge_labels = [one_hot(e) for e in edge_labels]
         if len(edge_labels) == 1:
             edge_label = edge_labels[0]
@@ -108,11 +107,11 @@ def cat(seq: List[Optional[Tensor]]) -> Optional[Tensor]:


 def split(data: Data, batch: Tensor) -> Tuple[Data, Dict[str, Tensor]]:
-    node_slice = cumsum(torch.from_numpy(np.bincount(batch)))
+    node_slice = cumsum(torch.bincount(batch))

     assert data.edge_index is not None
     row, _ = data.edge_index
-    edge_slice = cumsum(torch.from_numpy(np.bincount(batch[row])))
+    edge_slice = cumsum(torch.bincount(batch[row]))

     # Edge indices should start at zero for every graph.
     data.edge_index -= node_slice[batch[row]].unsqueeze(0)
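
Note: a small sketch of the slicing logic in `split` after the switch from `np.bincount` to `torch.bincount`; the toy `batch` vector is invented. `cumsum` here is `torch_geometric.utils.cumsum`, which prepends a leading zero.

```python
import torch

from torch_geometric.utils import cumsum

batch = torch.tensor([0, 0, 0, 1, 1, 2])    # graph assignment of each node
node_slice = cumsum(torch.bincount(batch))  # -> tensor([0, 3, 5, 6])
print(node_slice)
```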
torch_geometric/loader/__init__.py CHANGED
@@ -22,6 +22,7 @@ from .dynamic_batch_sampler import DynamicBatchSampler
 from .prefetch import PrefetchLoader
 from .cache import CachedLoader
 from .mixin import AffinityMixin
+from .rag_loader import RAGQueryLoader, RAGFeatureStore, RAGGraphStore

 __all__ = classes = [
     'DataLoader',
@@ -50,6 +51,9 @@ __all__ = classes = [
     'PrefetchLoader',
     'CachedLoader',
     'AffinityMixin',
+    'RAGQueryLoader',
+    'RAGFeatureStore',
+    'RAGGraphStore'
 ]

 RandomNodeSampler = deprecated(
torch_geometric/loader/cluster.py CHANGED
@@ -1,4 +1,5 @@
 import copy
+import os
 import os.path as osp
 import sys
 from dataclasses import dataclass
@@ -10,10 +11,11 @@ from torch import Tensor

 import torch_geometric.typing
 from torch_geometric.data import Data
+from torch_geometric.index import index2ptr, ptr2index
+from torch_geometric.io import fs
 from torch_geometric.typing import pyg_lib
 from torch_geometric.utils import index_sort, narrow, select, sort_edge_index
 from torch_geometric.utils.map import map_index
-from torch_geometric.utils.sparse import index2ptr, ptr2index


 @dataclass
@@ -43,6 +45,8 @@ class ClusterData(torch.utils.data.Dataset):
            (default: :obj:`False`)
        save_dir (str, optional): If set, will save the partitioned data to the
            :obj:`save_dir` directory for faster re-use. (default: :obj:`None`)
+        filename (str, optional): Name of the stored partitioned file.
+            (default: :obj:`None`)
        log (bool, optional): If set to :obj:`False`, will not log any
            progress. (default: :obj:`True`)
        keep_inter_cluster_edges (bool, optional): If set to :obj:`True`,
@@ -56,6 +60,7 @@ class ClusterData(torch.utils.data.Dataset):
        num_parts: int,
        recursive: bool = False,
        save_dir: Optional[str] = None,
+        filename: Optional[str] = None,
        log: bool = True,
        keep_inter_cluster_edges: bool = False,
        sparse_format: Literal['csr', 'csc'] = 'csr',
@@ -69,11 +74,11 @@ class ClusterData(torch.utils.data.Dataset):
        self.sparse_format = sparse_format

        recursive_str = '_recursive' if recursive else ''
-        filename = f'metis_{num_parts}{recursive_str}.pt'
-        path = osp.join(save_dir or '', filename)
+        root_dir = osp.join(save_dir or '', f'part_{num_parts}{recursive_str}')
+        path = osp.join(root_dir, filename or 'metis.pt')

        if save_dir is not None and osp.exists(path):
-            self.partition = torch.load(path)
+            self.partition = fs.torch_load(path)
        else:
            if log:  # pragma: no cover
                print('Computing METIS partitioning...', file=sys.stderr)
@@ -82,6 +87,7 @@
            self.partition = self._partition(data.edge_index, cluster)

            if save_dir is not None:
+                os.makedirs(root_dir, exist_ok=True)
                torch.save(self.partition, path)

        if log:  # pragma: no cover
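
Note: a short sketch of the new `filename` argument of `ClusterData` (dataset and paths are illustrative, and METIS partitioning still requires a `pyg-lib`/`torch-sparse` backend). Partitions are now cached under a `part_{num_parts}` subdirectory and reloaded via `fs.torch_load`.

```python
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import ClusterData

data = Planetoid(root='/tmp/Cora', name='Cora')[0]

# Cached at /tmp/Cora/parts/part_8/cora.pt and reused on the next run:
cluster_data = ClusterData(data, num_parts=8, save_dir='/tmp/Cora/parts',
                           filename='cora.pt')
```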
torch_geometric/loader/graph_saint.py CHANGED
@@ -4,6 +4,7 @@ from typing import Optional
 import torch
 from tqdm import tqdm

+from torch_geometric.io import fs
 from torch_geometric.typing import SparseTensor


@@ -77,7 +78,7 @@ class GraphSAINTSampler(torch.utils.data.DataLoader):
        if self.sample_coverage > 0:
            path = osp.join(save_dir or '', self._filename)
            if save_dir is not None and osp.exists(path):  # pragma: no cover
-                self.node_norm, self.edge_norm = torch.load(path)
+                self.node_norm, self.edge_norm = fs.torch_load(path)
            else:
                self.node_norm, self.edge_norm = self._compute_norm()
                if save_dir is not None:  # pragma: no cover
torch_geometric/loader/ibmb_loader.py CHANGED
@@ -1,9 +1,17 @@
 import logging
 import math
-from typing import Callable, Iterator, List, NamedTuple, Optional, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Iterator,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Union,
+)

 import numpy as np
-import scipy.sparse
 import torch
 from torch import Tensor
 from tqdm import tqdm
@@ -281,7 +289,7 @@ def create_batchwise_out_aux_pairs(
     return loader


-def get_pairs(ppr_mat: scipy.sparse.csr_matrix) -> np.ndarray:
+def get_pairs(ppr_mat: Any) -> np.ndarray:
     ppr_mat = ppr_mat + ppr_mat.transpose()

     ppr_mat = ppr_mat.tocoo()
@@ -387,7 +395,7 @@ def topk_ppr_matrix(
    output_node_indices: Union[np.ndarray, torch.LongTensor],
    topk: int,
    normalization='row',
-) -> Tuple[scipy.sparse.csr_matrix, List[np.ndarray]]:
+) -> Tuple[Any, List[np.ndarray]]:
    neighbors, weights = get_ppr(edge_index, alpha, eps, output_node_indices,
                                 num_nodes)

torch_geometric/loader/mixin.py CHANGED
@@ -56,7 +56,7 @@ def get_numa_nodes_cores() -> Dict[str, Any]:
            nodes[numa_node_id] = sorted([(k, sorted(v))
                                          for k, v in thread_siblings.items()])

-    except (OSError, ValueError, IndexError, IOError):
+    except (OSError, ValueError, IndexError):
        Warning('Failed to read NUMA info')
        return {}

torch_geometric/loader/neighbor_loader.py CHANGED
@@ -14,7 +14,7 @@ class NeighborLoader(NodeLoader):
    This loader allows for mini-batch training of GNNs on large-scale graphs
    where full-batch training is not feasible.

-    More specifically, :obj:`num_neighbors` denotes how much neighbors are
+    More specifically, :obj:`num_neighbors` denotes how many neighbors are
    sampled for each node in each iteration.
    :class:`~torch_geometric.loader.NeighborLoader` takes in this list of
    :obj:`num_neighbors` and iteratively samples :obj:`num_neighbors[i]` for
torch_geometric/loader/neighbor_sampler.py CHANGED
@@ -72,9 +72,9 @@ class NeighborSampler(torch.utils.data.DataLoader):
    `examples/reddit.py
    <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
    reddit.py>`_ or
-    `examples/ogbn_products_sage.py
+    `examples/ogbn_train.py
    <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
-    ogbn_products_sage.py>`_.
+    ogbn_train.py>`_.

    Args:
        edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a
torch_geometric/loader/prefetch.py CHANGED
@@ -73,7 +73,7 @@ class PrefetchLoader:
        if isinstance(batch, dict):
            return {k: self.non_blocking_transfer(v) for k, v in batch.items()}

-        batch = batch.pin_memory(self.device_helper.device)
+        batch = batch.pin_memory()
        return batch.to(self.device_helper.device, non_blocking=True)

    def __iter__(self) -> Any:
torch_geometric/loader/rag_loader.py ADDED
@@ -0,0 +1,107 @@
+from abc import abstractmethod
+from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
+
+from torch_geometric.data import Data, FeatureStore, HeteroData
+from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
+from torch_geometric.typing import InputEdges, InputNodes
+
+
+class RAGFeatureStore(Protocol):
+    """Feature store template for remote GNN RAG backend."""
+    @abstractmethod
+    def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes:
+        """Makes a comparison between the query and all the nodes to get all
+        the closest nodes. Return the indices of the nodes that are to be seeds
+        for the RAG Sampler.
+        """
+        ...
+
+    @abstractmethod
+    def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges:
+        """Makes a comparison between the query and all the edges to get all
+        the closest nodes. Returns the edge indices that are to be the seeds
+        for the RAG Sampler.
+        """
+        ...
+
+    @abstractmethod
+    def load_subgraph(
+        self, sample: Union[SamplerOutput, HeteroSamplerOutput]
+    ) -> Union[Data, HeteroData]:
+        """Combines sampled subgraph output with features in a Data object."""
+        ...
+
+
+class RAGGraphStore(Protocol):
+    """Graph store template for remote GNN RAG backend."""
+    @abstractmethod
+    def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,
+                        **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:
+        """Sample a subgraph using the seeded nodes and edges."""
+        ...
+
+    @abstractmethod
+    def register_feature_store(self, feature_store: FeatureStore):
+        """Register a feature store to be used with the sampler. Samplers need
+        info from the feature store in order to work properly on HeteroGraphs.
+        """
+        ...
+
+
+# TODO: Make compatible with Heterographs
+
+
+class RAGQueryLoader:
+    """Loader meant for making RAG queries from a remote backend."""
+    def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
+                 local_filter: Optional[Callable[[Data, Any], Data]] = None,
+                 seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
+                 seed_edges_kwargs: Optional[Dict[str, Any]] = None,
+                 sampler_kwargs: Optional[Dict[str, Any]] = None,
+                 loader_kwargs: Optional[Dict[str, Any]] = None):
+        """Loader meant for making queries from a remote backend.
+
+        Args:
+            data (Tuple[RAGFeatureStore, RAGGraphStore]): Remote FeatureStore
+                and GraphStore to load from. Assumed to conform to the
+                protocols listed above.
+            local_filter (Optional[Callable[[Data, Any], Data]], optional):
+                Optional local transform to apply to data after retrieval.
+                Defaults to None.
+            seed_nodes_kwargs (Optional[Dict[str, Any]], optional): Paramaters
+                to pass into process for fetching seed nodes. Defaults to None.
+            seed_edges_kwargs (Optional[Dict[str, Any]], optional): Parameters
+                to pass into process for fetching seed edges. Defaults to None.
+            sampler_kwargs (Optional[Dict[str, Any]], optional): Parameters to
+                pass into process for sampling graph. Defaults to None.
+            loader_kwargs (Optional[Dict[str, Any]], optional): Parameters to
+                pass into process for loading graph features. Defaults to None.
+        """
+        fstore, gstore = data
+        self.feature_store = fstore
+        self.graph_store = gstore
+        self.graph_store.register_feature_store(self.feature_store)
+        self.local_filter = local_filter
+        self.seed_nodes_kwargs = seed_nodes_kwargs or {}
+        self.seed_edges_kwargs = seed_edges_kwargs or {}
+        self.sampler_kwargs = sampler_kwargs or {}
+        self.loader_kwargs = loader_kwargs or {}
+
+    def query(self, query: Any) -> Data:
+        """Retrieve a subgraph associated with the query with all its feature
+        attributes.
+        """
+        seed_nodes = self.feature_store.retrieve_seed_nodes(
+            query, **self.seed_nodes_kwargs)
+        seed_edges = self.feature_store.retrieve_seed_edges(
+            query, **self.seed_edges_kwargs)
+
+        subgraph_sample = self.graph_store.sample_subgraph(
+            seed_nodes, seed_edges, **self.sampler_kwargs)
+
+        data = self.feature_store.load_subgraph(sample=subgraph_sample,
+                                                **self.loader_kwargs)
+
+        if self.local_filter:
+            data = self.local_filter(data, query)
+        return data
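
Note: a toy sketch of wiring the two protocols above into a `RAGQueryLoader`. The `ToyFeatureStore`/`ToyGraphStore` classes are hypothetical in-memory stand-ins for a real remote backend; only the loader and the protocol method names come from this file.

```python
from typing import Any

import torch

from torch_geometric.data import Data
from torch_geometric.loader import RAGQueryLoader
from torch_geometric.sampler import SamplerOutput


class ToyFeatureStore:
    """Hypothetical in-memory backend satisfying the RAGFeatureStore protocol."""
    def retrieve_seed_nodes(self, query: Any, k: int = 2):
        return torch.arange(k)  # pretend the first `k` nodes match the query

    def retrieve_seed_edges(self, query: Any):
        return torch.empty(2, 0, dtype=torch.long)  # no seed edges in this toy

    def load_subgraph(self, sample: SamplerOutput) -> Data:
        num_nodes = sample.node.numel()
        edge_index = torch.stack([sample.row, sample.col], dim=0)
        return Data(x=torch.randn(num_nodes, 8), edge_index=edge_index)


class ToyGraphStore:
    """Hypothetical backend satisfying the RAGGraphStore protocol."""
    def register_feature_store(self, feature_store):
        self.feature_store = feature_store

    def sample_subgraph(self, seed_nodes, seed_edges) -> SamplerOutput:
        n = seed_nodes.numel()  # return the seeds with dummy self-loops
        return SamplerOutput(node=seed_nodes, row=torch.arange(n),
                             col=torch.arange(n), edge=torch.arange(n))


loader = RAGQueryLoader(data=(ToyFeatureStore(), ToyGraphStore()),
                        seed_nodes_kwargs={'k': 2})
print(loader.query('toy query'))
```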
torch_geometric/loader/utils.py CHANGED
@@ -8,7 +8,6 @@ import torch
 from torch import Tensor

 import torch_geometric.typing
-from torch_geometric import EdgeIndex
 from torch_geometric.data import (
     Data,
     FeatureStore,
@@ -105,13 +104,15 @@ def filter_edge_store_(store: EdgeStorage, out_store: EdgeStorage, row: Tensor,
    # which represents the new graph as denoted by `(row, col)`:
    for key, value in store.items():
        if key == 'edge_index':
+            edge_index = torch.stack([row, col], dim=0).to(value.device)
            # TODO Integrate `EdgeIndex` into `custom_store`.
-            out_store.edge_index = EdgeIndex(
-                torch.stack([row, col], dim=0).to(value.device),
-                sparse_size=out_store.size(),
-                sort_order='col',
-                # TODO Support `is_undirected`.
-            )
+            # edge_index = EdgeIndex(
+            #     torch.stack([row, col], dim=0).to(value.device),
+            #     sparse_size=out_store.size(),
+            #     sort_order='col',
+            #     # TODO Support `is_undirected`.
+            # )
+            out_store.edge_index = edge_index

        elif key == 'adj_t':
            # NOTE: We expect `(row, col)` to be sorted by `col` (CSC layout).
torch_geometric/loader/zip_loader.py CHANGED
@@ -59,6 +59,16 @@ class ZipLoader(torch.utils.data.DataLoader):
        self.loaders = loaders
        self.filter_per_worker = filter_per_worker

+    def __call__(
+        self,
+        index: Union[Tensor, List[int]],
+    ) -> Union[Tuple[Data, ...], Tuple[HeteroData, ...]]:
+        r"""Samples subgraphs from a batch of input IDs."""
+        out = self.collate_fn(index)
+        if not self.filter_per_worker:
+            out = self.filter_fn(out)
+        return out
+
    def collate_fn(self, index: List[int]) -> Tuple[Any, ...]:
        if not isinstance(index, Tensor):
            index = torch.tensor(index, dtype=torch.long)
torch_geometric/metrics/__init__.py CHANGED
@@ -1,14 +1,23 @@
 # flake8: noqa

-from .link_pred import (LinkPredPrecision, LinkPredRecall, LinkPredF1,
-                        LinkPredMAP, LinkPredNDCG)
+from .link_pred import (
+    LinkPredMetricCollection,
+    LinkPredPrecision,
+    LinkPredRecall,
+    LinkPredF1,
+    LinkPredMAP,
+    LinkPredNDCG,
+    LinkPredMRR,
+)

 link_pred_metrics = [
+    'LinkPredMetricCollection',
     'LinkPredPrecision',
     'LinkPredRecall',
     'LinkPredF1',
     'LinkPredMAP',
     'LinkPredNDCG',
+    'LinkPredMRR',
 ]

 __all__ = link_pred_metrics
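
Note: a rough sketch of the newly exported link prediction metrics; the `update`/`compute` call pattern follows my reading of `torch_geometric.metrics.link_pred` and may differ slightly across versions.

```python
import torch

from torch_geometric.metrics import (
    LinkPredMAP,
    LinkPredMetricCollection,
    LinkPredMRR,
)

k = 3
# Top-k predicted destination nodes per source node (ranked left to right):
pred_index_mat = torch.tensor([[1, 2, 3], [0, 2, 4]])
# Ground-truth (source, destination) pairs:
edge_label_index = torch.tensor([[0, 0, 1], [2, 3, 4]])

metrics = LinkPredMetricCollection([LinkPredMAP(k=k), LinkPredMRR(k=k)])
metrics.update(pred_index_mat, edge_label_index)
print(metrics.compute())
```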