pyg-nightly 2.6.0.dev20240511__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Files changed (205)
  1. {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +30 -31
  2. {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +205 -181
  3. {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +26 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +16 -14
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/data.py +13 -8
  12. torch_geometric/data/database.py +15 -7
  13. torch_geometric/data/dataset.py +14 -6
  14. torch_geometric/data/feature_store.py +13 -22
  15. torch_geometric/data/graph_store.py +0 -4
  16. torch_geometric/data/hetero_data.py +4 -4
  17. torch_geometric/data/in_memory_dataset.py +2 -4
  18. torch_geometric/data/large_graph_indexer.py +677 -0
  19. torch_geometric/data/lightning/datamodule.py +4 -4
  20. torch_geometric/data/storage.py +15 -5
  21. torch_geometric/data/summary.py +14 -4
  22. torch_geometric/data/temporal.py +1 -2
  23. torch_geometric/datasets/__init__.py +11 -1
  24. torch_geometric/datasets/actor.py +9 -11
  25. torch_geometric/datasets/airfrans.py +15 -18
  26. torch_geometric/datasets/airports.py +10 -12
  27. torch_geometric/datasets/amazon.py +8 -11
  28. torch_geometric/datasets/amazon_book.py +9 -10
  29. torch_geometric/datasets/amazon_products.py +9 -10
  30. torch_geometric/datasets/aminer.py +8 -9
  31. torch_geometric/datasets/aqsol.py +10 -13
  32. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  33. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  34. torch_geometric/datasets/ba_shapes.py +5 -6
  35. torch_geometric/datasets/bitcoin_otc.py +1 -1
  36. torch_geometric/datasets/brca_tgca.py +1 -1
  37. torch_geometric/datasets/dblp.py +2 -1
  38. torch_geometric/datasets/dbp15k.py +2 -2
  39. torch_geometric/datasets/fake.py +1 -3
  40. torch_geometric/datasets/flickr.py +2 -1
  41. torch_geometric/datasets/freebase.py +1 -1
  42. torch_geometric/datasets/gdelt_lite.py +3 -2
  43. torch_geometric/datasets/ged_dataset.py +3 -2
  44. torch_geometric/datasets/git_mol_dataset.py +263 -0
  45. torch_geometric/datasets/gnn_benchmark_dataset.py +6 -5
  46. torch_geometric/datasets/hgb_dataset.py +8 -8
  47. torch_geometric/datasets/imdb.py +2 -1
  48. torch_geometric/datasets/last_fm.py +2 -1
  49. torch_geometric/datasets/linkx_dataset.py +4 -3
  50. torch_geometric/datasets/lrgb.py +3 -5
  51. torch_geometric/datasets/malnet_tiny.py +4 -3
  52. torch_geometric/datasets/mnist_superpixels.py +2 -3
  53. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  54. torch_geometric/datasets/molecule_net.py +7 -1
  55. torch_geometric/datasets/motif_generator/base.py +0 -1
  56. torch_geometric/datasets/neurograph.py +1 -3
  57. torch_geometric/datasets/ogb_mag.py +1 -1
  58. torch_geometric/datasets/opf.py +239 -0
  59. torch_geometric/datasets/ose_gvcs.py +1 -1
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  62. torch_geometric/datasets/pcqm4m.py +2 -1
  63. torch_geometric/datasets/ppi.py +1 -1
  64. torch_geometric/datasets/qm9.py +4 -3
  65. torch_geometric/datasets/reddit.py +2 -1
  66. torch_geometric/datasets/reddit2.py +2 -1
  67. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  68. torch_geometric/datasets/s3dis.py +2 -2
  69. torch_geometric/datasets/shapenet.py +3 -3
  70. torch_geometric/datasets/shrec2016.py +2 -2
  71. torch_geometric/datasets/tag_dataset.py +350 -0
  72. torch_geometric/datasets/upfd.py +2 -1
  73. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  74. torch_geometric/datasets/webkb.py +2 -2
  75. torch_geometric/datasets/wikics.py +1 -1
  76. torch_geometric/datasets/wikidata.py +3 -2
  77. torch_geometric/datasets/wikipedia_network.py +2 -2
  78. torch_geometric/datasets/word_net.py +2 -2
  79. torch_geometric/datasets/yelp.py +2 -1
  80. torch_geometric/datasets/zinc.py +1 -1
  81. torch_geometric/device.py +42 -0
  82. torch_geometric/distributed/local_feature_store.py +3 -2
  83. torch_geometric/distributed/local_graph_store.py +2 -1
  84. torch_geometric/distributed/partition.py +9 -8
  85. torch_geometric/edge_index.py +17 -8
  86. torch_geometric/explain/algorithm/base.py +0 -1
  87. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  88. torch_geometric/explain/explanation.py +2 -2
  89. torch_geometric/graphgym/checkpoint.py +2 -1
  90. torch_geometric/graphgym/logger.py +4 -4
  91. torch_geometric/graphgym/loss.py +1 -1
  92. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  93. torch_geometric/index.py +20 -7
  94. torch_geometric/inspector.py +6 -2
  95. torch_geometric/io/fs.py +28 -2
  96. torch_geometric/io/npz.py +2 -1
  97. torch_geometric/io/off.py +2 -2
  98. torch_geometric/io/sdf.py +2 -2
  99. torch_geometric/io/tu.py +2 -3
  100. torch_geometric/loader/__init__.py +4 -0
  101. torch_geometric/loader/cluster.py +9 -3
  102. torch_geometric/loader/graph_saint.py +2 -1
  103. torch_geometric/loader/ibmb_loader.py +12 -4
  104. torch_geometric/loader/mixin.py +1 -1
  105. torch_geometric/loader/neighbor_loader.py +1 -1
  106. torch_geometric/loader/neighbor_sampler.py +2 -2
  107. torch_geometric/loader/prefetch.py +1 -1
  108. torch_geometric/loader/rag_loader.py +107 -0
  109. torch_geometric/loader/zip_loader.py +10 -0
  110. torch_geometric/metrics/__init__.py +11 -2
  111. torch_geometric/metrics/link_pred.py +159 -34
  112. torch_geometric/nn/aggr/__init__.py +2 -0
  113. torch_geometric/nn/aggr/attention.py +0 -2
  114. torch_geometric/nn/aggr/base.py +2 -4
  115. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  116. torch_geometric/nn/aggr/set_transformer.py +1 -1
  117. torch_geometric/nn/attention/__init__.py +5 -1
  118. torch_geometric/nn/attention/qformer.py +71 -0
  119. torch_geometric/nn/conv/collect.jinja +6 -3
  120. torch_geometric/nn/conv/cugraph/base.py +0 -1
  121. torch_geometric/nn/conv/edge_conv.py +3 -2
  122. torch_geometric/nn/conv/gat_conv.py +35 -7
  123. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  124. torch_geometric/nn/conv/general_conv.py +1 -1
  125. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  126. torch_geometric/nn/conv/hetero_conv.py +3 -3
  127. torch_geometric/nn/conv/hgt_conv.py +1 -1
  128. torch_geometric/nn/conv/message_passing.py +100 -82
  129. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  130. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  131. torch_geometric/nn/conv/spline_conv.py +4 -4
  132. torch_geometric/nn/conv/x_conv.py +3 -2
  133. torch_geometric/nn/dense/linear.py +5 -4
  134. torch_geometric/nn/fx.py +3 -3
  135. torch_geometric/nn/model_hub.py +3 -1
  136. torch_geometric/nn/models/__init__.py +10 -2
  137. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  138. torch_geometric/nn/models/dimenet_utils.py +5 -7
  139. torch_geometric/nn/models/g_retriever.py +230 -0
  140. torch_geometric/nn/models/git_mol.py +336 -0
  141. torch_geometric/nn/models/glem.py +385 -0
  142. torch_geometric/nn/models/gnnff.py +0 -1
  143. torch_geometric/nn/models/graph_unet.py +12 -3
  144. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  145. torch_geometric/nn/models/lightgcn.py +1 -1
  146. torch_geometric/nn/models/metapath2vec.py +3 -4
  147. torch_geometric/nn/models/molecule_gpt.py +222 -0
  148. torch_geometric/nn/models/node2vec.py +1 -2
  149. torch_geometric/nn/models/schnet.py +2 -1
  150. torch_geometric/nn/models/signed_gcn.py +3 -3
  151. torch_geometric/nn/module_dict.py +2 -2
  152. torch_geometric/nn/nlp/__init__.py +9 -0
  153. torch_geometric/nn/nlp/llm.py +322 -0
  154. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  155. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  156. torch_geometric/nn/norm/batch_norm.py +1 -1
  157. torch_geometric/nn/parameter_dict.py +2 -2
  158. torch_geometric/nn/pool/__init__.py +7 -5
  159. torch_geometric/nn/pool/cluster_pool.py +145 -0
  160. torch_geometric/nn/pool/connect/base.py +0 -1
  161. torch_geometric/nn/pool/edge_pool.py +1 -1
  162. torch_geometric/nn/pool/graclus.py +4 -2
  163. torch_geometric/nn/pool/select/base.py +0 -1
  164. torch_geometric/nn/pool/voxel_grid.py +3 -2
  165. torch_geometric/nn/resolver.py +1 -1
  166. torch_geometric/nn/sequential.jinja +10 -23
  167. torch_geometric/nn/sequential.py +203 -77
  168. torch_geometric/nn/summary.py +1 -1
  169. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  170. torch_geometric/profile/__init__.py +2 -0
  171. torch_geometric/profile/nvtx.py +66 -0
  172. torch_geometric/profile/profiler.py +24 -15
  173. torch_geometric/resolver.py +1 -1
  174. torch_geometric/sampler/base.py +34 -13
  175. torch_geometric/sampler/neighbor_sampler.py +11 -10
  176. torch_geometric/testing/decorators.py +17 -22
  177. torch_geometric/transforms/__init__.py +2 -0
  178. torch_geometric/transforms/add_metapaths.py +4 -4
  179. torch_geometric/transforms/add_positional_encoding.py +1 -1
  180. torch_geometric/transforms/delaunay.py +65 -14
  181. torch_geometric/transforms/face_to_edge.py +32 -3
  182. torch_geometric/transforms/gdc.py +7 -6
  183. torch_geometric/transforms/laplacian_lambda_max.py +2 -2
  184. torch_geometric/transforms/mask.py +5 -1
  185. torch_geometric/transforms/node_property_split.py +1 -2
  186. torch_geometric/transforms/pad.py +7 -6
  187. torch_geometric/transforms/random_link_split.py +1 -1
  188. torch_geometric/transforms/remove_self_loops.py +36 -0
  189. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  190. torch_geometric/transforms/virtual_node.py +2 -1
  191. torch_geometric/typing.py +31 -5
  192. torch_geometric/utils/__init__.py +5 -1
  193. torch_geometric/utils/_negative_sampling.py +1 -1
  194. torch_geometric/utils/_normalize_edge_index.py +46 -0
  195. torch_geometric/utils/_scatter.py +37 -12
  196. torch_geometric/utils/_subgraph.py +4 -0
  197. torch_geometric/utils/_tree_decomposition.py +2 -2
  198. torch_geometric/utils/augmentation.py +1 -1
  199. torch_geometric/utils/convert.py +5 -5
  200. torch_geometric/utils/geodesic.py +24 -22
  201. torch_geometric/utils/hetero.py +1 -1
  202. torch_geometric/utils/map.py +1 -1
  203. torch_geometric/utils/smiles.py +66 -28
  204. torch_geometric/utils/sparse.py +25 -10
  205. torch_geometric/visualization/graph.py +3 -4
torch_geometric/datasets/word_net.py CHANGED
@@ -67,7 +67,7 @@ class WordNet18(InMemoryDataset):
     def process(self) -> None:
         srcs, dsts, edge_types = [], [], []
         for path in self.raw_paths:
-            with open(path, 'r') as f:
+            with open(path) as f:
                 edges = [int(x) for x in f.read().split()[1:]]
                 edge = torch.tensor(edges, dtype=torch.long)
                 srcs.append(edge[::3])
@@ -173,7 +173,7 @@ class WordNet18RR(InMemoryDataset):
 
         srcs, dsts, edge_types = [], [], []
         for path in self.raw_paths:
-            with open(path, 'r') as f:
+            with open(path) as f:
                 edges = f.read().split()
 
                 _src = edges[::3]

torch_geometric/datasets/yelp.py CHANGED
@@ -3,7 +3,6 @@ import os.path as osp
 from typing import Callable, List, Optional
 
 import numpy as np
-import scipy.sparse as sp
 import torch
 
 from torch_geometric.data import Data, InMemoryDataset, download_google_url
@@ -73,6 +72,8 @@ class Yelp(InMemoryDataset):
         download_google_url(self.role_id, self.raw_dir, 'role.json')
 
     def process(self) -> None:
+        import scipy.sparse as sp
+
         f = np.load(osp.join(self.raw_dir, 'adj_full.npz'))
         adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape'])
         adj = adj.tocoo()

torch_geometric/datasets/zinc.py CHANGED
@@ -139,7 +139,7 @@ class ZINC(InMemoryDataset):
             indices = list(range(len(mols)))
 
             if self.subset:
-                with open(osp.join(self.raw_dir, f'{split}.index'), 'r') as f:
+                with open(osp.join(self.raw_dir, f'{split}.index')) as f:
                     indices = [int(x) for x in f.read()[:-1].split(',')]
 
             pbar = tqdm(total=len(indices))

torch_geometric/device.py ADDED
@@ -0,0 +1,42 @@
+from typing import Any
+
+import torch
+
+
+def is_mps_available() -> bool:
+    r"""Returns a bool indicating if MPS is currently available."""
+    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+        try:  # Github CI may not have access to MPS hardware. Confirm:
+            torch.empty(1, device='mps')
+            return True
+        except Exception:
+            return False
+    return False
+
+
+def is_xpu_available() -> bool:
+    r"""Returns a bool indicating if XPU is currently available."""
+    if hasattr(torch, 'xpu') and torch.xpu.is_available():
+        return True
+    try:
+        import intel_extension_for_pytorch as ipex
+        return ipex.xpu.is_available()
+    except ImportError:
+        return False
+
+
+def device(device: Any) -> torch.device:
+    r"""Returns a :class:`torch.device`.
+
+    If :obj:`"auto"` is specified, returns the optimal device depending on
+    available hardware.
+    """
+    if device != 'auto':
+        return torch.device(device)
+    if torch.cuda.is_available():
+        return torch.device('cuda')
+    if is_mps_available():
+        return torch.device('mps')
+    if is_xpu_available():
+        return torch.device('xpu')
+    return torch.device('cpu')

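The new `torch_geometric/device.py` module selects an accelerator automatically. A minimal usage sketch, assuming the functions above are importable from `torch_geometric.device` (how they are re-exported at package level is not shown in this hunk):

import torch
from torch_geometric.device import device, is_mps_available

dev = device('auto')  # resolves to 'cuda', 'mps', 'xpu', or 'cpu' via the checks above
x = torch.randn(4, 16, device=dev)
print(dev, is_mps_available())
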
torch_geometric/distributed/local_feature_store.py CHANGED
@@ -15,6 +15,7 @@ from torch_geometric.distributed.rpc import (
     rpc_async,
     rpc_register,
 )
+from torch_geometric.io import fs
 from torch_geometric.typing import EdgeType, NodeOrEdgeType, NodeType
 
 
@@ -415,11 +416,11 @@ class LocalFeatureStore(FeatureStore):
 
         node_feats: Optional[Dict[str, Any]] = None
         if osp.exists(osp.join(part_dir, 'node_feats.pt')):
-            node_feats = torch.load(osp.join(part_dir, 'node_feats.pt'))
+            node_feats = fs.torch_load(osp.join(part_dir, 'node_feats.pt'))
 
         edge_feats: Optional[Dict[str, Any]] = None
         if osp.exists(osp.join(part_dir, 'edge_feats.pt')):
-            edge_feats = torch.load(osp.join(part_dir, 'edge_feats.pt'))
+            edge_feats = fs.torch_load(osp.join(part_dir, 'edge_feats.pt'))
 
         if not meta['is_hetero'] and node_feats is not None:
             feat_store.put_global_id(node_feats['global_id'], group_name=None)

torch_geometric/distributed/local_graph_store.py CHANGED
@@ -6,6 +6,7 @@ from torch import Tensor
 
 from torch_geometric.data import EdgeAttr, GraphStore
 from torch_geometric.distributed.partition import load_partition_info
+from torch_geometric.io import fs
 from torch_geometric.typing import EdgeTensorType, EdgeType, NodeType
 from torch_geometric.utils import sort_edge_index
 
@@ -185,7 +186,7 @@ class LocalGraphStore(GraphStore):
         graph_store.edge_pb = edge_pb
         graph_store.meta = meta
 
-        graph_data = torch.load(osp.join(part_dir, 'graph.pt'))
+        graph_data = fs.torch_load(osp.join(part_dir, 'graph.pt'))
         graph_store.is_sorted = meta['is_sorted']
 
         if not meta['is_hetero']:

torch_geometric/distributed/partition.py CHANGED
@@ -3,15 +3,16 @@ import logging
 import os
 import os.path as osp
 from collections import defaultdict
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 
 import torch_geometric.distributed as pyg_dist
 from torch_geometric.data import Data, HeteroData
+from torch_geometric.io import fs
 from torch_geometric.loader.cluster import ClusterData
 from torch_geometric.sampler.utils import sort_csc
-from torch_geometric.typing import Dict, EdgeType, EdgeTypeStr, NodeType, Tuple
+from torch_geometric.typing import EdgeType, EdgeTypeStr, NodeType
 
 
 class Partitioner:
@@ -23,7 +24,7 @@ class Partitioner:
 
     **Homogeneous graphs:**
 
-    .. code-block::
+    .. code-block:: none
 
         root/
         |-- META.json
@@ -40,7 +41,7 @@ class Partitioner:
 
     **Heterogeneous graphs:**
 
-    .. code-block::
+    .. code-block:: none
 
         root/
         |-- META.json
@@ -380,21 +381,21 @@ def load_partition_info(
     assert osp.exists(partition_dir)
 
     if meta['is_hetero'] is False:
-        node_pb = torch.load(osp.join(root_dir, 'node_map.pt'))
-        edge_pb = torch.load(osp.join(root_dir, 'edge_map.pt'))
+        node_pb = fs.torch_load(osp.join(root_dir, 'node_map.pt'))
+        edge_pb = fs.torch_load(osp.join(root_dir, 'edge_map.pt'))
 
         return (meta, num_partitions, partition_idx, node_pb, edge_pb)
     else:
         node_pb_dict = {}
         node_pb_dir = osp.join(root_dir, 'node_map')
         for ntype in meta['node_types']:
-            node_pb_dict[ntype] = torch.load(
+            node_pb_dict[ntype] = fs.torch_load(
                 osp.join(node_pb_dir, f'{pyg_dist.utils.as_str(ntype)}.pt'))
 
         edge_pb_dict = {}
         edge_pb_dir = osp.join(root_dir, 'edge_map')
         for etype in meta['edge_types']:
-            edge_pb_dict[tuple(etype)] = torch.load(
+            edge_pb_dict[tuple(etype)] = fs.torch_load(
                 osp.join(edge_pb_dir, f'{pyg_dist.utils.as_str(etype)}.pt'))
 
         return (meta, num_partitions, partition_idx, node_pb_dict,

torch_geometric/edge_index.py CHANGED
@@ -173,7 +173,7 @@ class EdgeIndex(Tensor):
     :meth:`EdgeIndex.fill_cache_`, and are maintained and adjusted over its
     lifespan (*e.g.*, when calling :meth:`EdgeIndex.flip`).
 
-    This representation ensures for optimal computation in GNN message passing
+    This representation ensures optimal computation in GNN message passing
     schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
     workflows.
 
@@ -820,19 +820,28 @@ class EdgeIndex(Tensor):
                 :obj:`1.0`. (default: :obj:`None`)
         """
         value = self._get_value() if value is None else value
-        out = torch.sparse_coo_tensor(
+
+        if not torch_geometric.typing.WITH_PT21:
+            out = torch.sparse_coo_tensor(
+                indices=self._data,
+                values=value,
+                size=self.get_sparse_size(),
+                device=self.device,
+                requires_grad=value.requires_grad,
+            )
+            if self.is_sorted_by_row:
+                out = out._coalesced_(True)
+            return out
+
+        return torch.sparse_coo_tensor(
             indices=self._data,
             values=value,
             size=self.get_sparse_size(),
             device=self.device,
             requires_grad=value.requires_grad,
+            is_coalesced=True if self.is_sorted_by_row else None,
         )
 
-        if self.is_sorted_by_row:
-            out = out._coalesced_(True)
-
-        return out
-
 
     def to_sparse_csr(  # type: ignore
         self,
@@ -1928,7 +1937,7 @@ def _spmm(
     if transpose and not input.is_sorted_by_col:
         cls_name = input.__class__.__name__
         raise ValueError(f"'matmul(..., transpose=True)' requires "
-                         f"'{cls_name}' to be sorted by colums")
+                         f"'{cls_name}' to be sorted by columns")
 
     if (torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling()
             and other.is_cuda):  # pragma: no cover

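For the `to_sparse_coo` change above: on PyTorch >= 2.1 the conversion now passes `is_coalesced=True` directly when the edges are sorted by row, instead of calling the private `_coalesced_` afterwards. A rough usage sketch (the public `EdgeIndex` constructor arguments shown here are an assumption and may differ between PyG versions):

import torch
from torch_geometric import EdgeIndex

# Row-sorted (3 x 3) adjacency; `sparse_size`/`sort_order` names are assumptions.
edge_index = EdgeIndex(
    torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
    sparse_size=(3, 3),
    sort_order='row',
)
adj = edge_index.to_sparse_coo()  # marked as coalesced since the input is row-sorted
assert adj.is_coalesced()
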
torch_geometric/explain/algorithm/base.py CHANGED
@@ -50,7 +50,6 @@ class ExplainerAlgorithm(torch.nn.Module):
         r"""Checks if the explainer supports the user-defined settings provided
         in :obj:`self.explainer_config`, :obj:`self.model_config`.
         """
-        pass
 
     ###########################################################################
 
torch_geometric/explain/algorithm/pg_explainer.py CHANGED
@@ -59,7 +59,7 @@ class PGExplainer(ExplainerAlgorithm):
         'edge_size': 0.05,
         'edge_ent': 1.0,
         'temp': [5.0, 2.0],
-        'bias': 0.0,
+        'bias': 0.01,
     }
 
     def __init__(self, epochs: int, lr: float = 0.003, **kwargs):

torch_geometric/explain/explanation.py CHANGED
@@ -340,10 +340,10 @@ class HeteroExplanation(HeteroData, ExplanationMixin):
         """
         node_mask_dict = self.node_mask_dict
         for node_mask in node_mask_dict.values():
-            if node_mask.dim() != 2 or node_mask.size(1) <= 1:
+            if node_mask.dim() != 2:
                 raise ValueError(f"Cannot compute feature importance for "
                                  f"object-level 'node_mask' "
-                                 f"(got shape {node_mask_dict.size()})")
+                                 f"(got shape {node_mask.size()})")
 
         if feat_labels is None:
             feat_labels = {}

torch_geometric/graphgym/checkpoint.py CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Union
 import torch
 
 from torch_geometric.graphgym.config import cfg
+from torch_geometric.io import fs
 
 MODEL_STATE = 'model_state'
 OPTIMIZER_STATE = 'optimizer_state'
@@ -25,7 +26,7 @@ def load_ckpt(
     if not osp.exists(path):
         return 0
 
-    ckpt = torch.load(path)
+    ckpt = fs.torch_load(path)
     model.load_state_dict(ckpt[MODEL_STATE])
     if optimizer is not None and OPTIMIZER_STATE in ckpt:
         optimizer.load_state_dict(ckpt[OPTIMIZER_STATE])

torch_geometric/graphgym/logger.py CHANGED
@@ -19,7 +19,7 @@ def set_printing():
     logging.root.handlers = []
     logging_cfg = {'level': logging.INFO, 'format': '%(message)s'}
     os.makedirs(cfg.run_dir, exist_ok=True)
-    h_file = logging.FileHandler('{}/logging.log'.format(cfg.run_dir))
+    h_file = logging.FileHandler(f'{cfg.run_dir}/logging.log')
     h_stdout = logging.StreamHandler(sys.stdout)
     if cfg.print == 'file':
         logging_cfg['handlers'] = [h_file]
@@ -40,7 +40,7 @@ class Logger:
         self._epoch_total = cfg.optim.max_epoch
         self._time_total = 0  # won't be reset
 
-        self.out_dir = '{}/{}'.format(cfg.run_dir, name)
+        self.out_dir = f'{cfg.run_dir}/{name}'
         os.makedirs(self.out_dir, exist_ok=True)
         if cfg.tensorboard_each_run:
             from tensorboardX import SummaryWriter
@@ -210,9 +210,9 @@ class Logger:
         }
 
         # print
-        logging.info('{}: {}'.format(self.name, stats))
+        logging.info(f'{self.name}: {stats}')
         # json
-        dict_to_json(stats, '{}/stats.json'.format(self.out_dir))
+        dict_to_json(stats, f'{self.out_dir}/stats.json')
         # tensorboard
         if cfg.tensorboard_each_run:
             dict_to_tb(stats, self.tb_writer, cur_epoch)

torch_geometric/graphgym/loss.py CHANGED
@@ -10,7 +10,7 @@ def compute_loss(pred, true):
 
     Args:
         pred (torch.tensor): Unnormalized prediction
-        true (torch.tensor): Grou
+        true (torch.tensor): Ground truth labels
 
     Returns: Loss, normalized prediction score
 
torch_geometric/graphgym/utils/agg_runs.py CHANGED
@@ -54,7 +54,7 @@ def agg_dict_list(dict_list):
         if key != 'epoch':
             value = np.array([dict[key] for dict in dict_list])
             dict_agg[key] = np.mean(value).round(cfg.round)
-            dict_agg['{}_std'.format(key)] = np.std(value).round(cfg.round)
+            dict_agg[f'{key}_std'] = np.std(value).round(cfg.round)
     return dict_agg
 
 
@@ -107,7 +107,7 @@ def agg_runs(dir, metric_best='auto'):
                     [stats[metric] for stats in stats_list])
                 best_epoch = \
                     stats_list[
-                        eval("performance_np.{}()".format(cfg.metric_agg))][
+                        eval(f"performance_np.{cfg.metric_agg}()")][
                         'epoch']
                 print(best_epoch)
 
@@ -190,7 +190,7 @@ def agg_batch(dir, metric_best='auto'):
             results[key] = pd.DataFrame(results[key])
             results[key] = results[key].sort_values(
                 list(dict_name.keys()), ascending=[True] * len(dict_name))
-            fname = osp.join(dir_out, '{}_best.csv'.format(key))
+            fname = osp.join(dir_out, f'{key}_best.csv')
             results[key].to_csv(fname, index=False)
 
     results = {'train': [], 'val': [], 'test': []}
@@ -213,7 +213,7 @@ def agg_batch(dir, metric_best='auto'):
             results[key] = pd.DataFrame(results[key])
             results[key] = results[key].sort_values(
                 list(dict_name.keys()), ascending=[True] * len(dict_name))
-            fname = osp.join(dir_out, '{}.csv'.format(key))
+            fname = osp.join(dir_out, f'{key}.csv')
             results[key].to_csv(fname, index=False)
 
     results = {'train': [], 'val': [], 'test': []}
@@ -245,7 +245,7 @@ def agg_batch(dir, metric_best='auto'):
             results[key] = pd.DataFrame(results[key])
             results[key] = results[key].sort_values(
                 list(dict_name.keys()), ascending=[True] * len(dict_name))
-            fname = osp.join(dir_out, '{}_bestepoch.csv'.format(key))
+            fname = osp.join(dir_out, f'{key}_bestepoch.csv')
             results[key].to_csv(fname, index=False)
 
-    print('Results aggregated across models saved in {}'.format(dir_out))
+    print(f'Results aggregated across models saved in {dir_out}')
torch_geometric/index.py CHANGED
@@ -106,7 +106,7 @@ class Index(Tensor):
     :meth:`Index.fill_cache_`, and are maintaned and adjusted over its
     lifespan.
 
-    This representation ensures for optimal computation in GNN message passing
+    This representation ensures optimal computation in GNN message passing
     schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
     workflows.
 
@@ -120,7 +120,7 @@ class Index(Tensor):
         assert index.is_sorted
 
         # Flipping order:
-        edge_index.flip(0)
+        index.flip(0)
         >>> Index([[2, 1, 1, 0], dim_size=3)
         assert not index.is_sorted
 
@@ -685,14 +685,14 @@ def _index(
 
 @implements(aten.add.Tensor)
 def _add(
-    input: Index,
+    input: Union[int, Tensor, Index],
     other: Union[int, Tensor, Index],
     *,
     alpha: int = 1,
 ) -> Union[Index, Tensor]:
 
     data = aten.add.Tensor(
-        input._data,
+        input._data if isinstance(input, Index) else input,
         other._data if isinstance(other, Index) else other,
         alpha=alpha,
     )
@@ -704,15 +704,25 @@ def _add(
 
     out = Index(data)
 
+    if isinstance(input, Tensor) and input.numel() <= 1:
+        input = int(input)
+
     if isinstance(other, Tensor) and other.numel() <= 1:
         other = int(other)
 
     if isinstance(other, int):
+        assert isinstance(input, Index)
         if input.dim_size is not None:
             out._dim_size = input.dim_size + alpha * other
         out._is_sorted = input.is_sorted
 
-    elif isinstance(other, Index):
+    elif isinstance(input, int):
+        assert isinstance(other, Index)
+        if other.dim_size is not None:
+            out._dim_size = input + alpha * other.dim_size
+        out._is_sorted = other.is_sorted
+
+    elif isinstance(input, Index) and isinstance(other, Index):
         if input.dim_size is not None and other.dim_size is not None:
             out._dim_size = input.dim_size + alpha * other.dim_size
 
@@ -754,14 +764,14 @@
 
 @implements(aten.sub.Tensor)
 def _sub(
-    input: Index,
+    input: Union[int, Tensor, Index],
     other: Union[int, Tensor, Index],
     *,
     alpha: int = 1,
 ) -> Union[Index, Tensor]:
 
     data = aten.sub.Tensor(
-        input._data,
+        input._data if isinstance(input, Index) else input,
         other._data if isinstance(other, Index) else other,
         alpha=alpha,
     )
@@ -773,6 +783,9 @@ def _sub(
 
     out = Index(data)
 
+    if not isinstance(input, Index):
+        return out
+
     if isinstance(other, Tensor) and other.numel() <= 1:
         other = int(other)
 
torch_geometric/inspector.py CHANGED
@@ -305,7 +305,7 @@ class Inspector:
         according to its function signature from a data blob.
 
         Args:
-            func (callabel or str): The function.
+            func (callable or str): The function.
             kwargs (dict[str, Any]): The data blob which may serve as inputs.
         """
         out_dict: Dict[str, Any] = {}
@@ -346,7 +346,7 @@ class Inspector:
         type annotations are not found.
 
         Args:
-            func (callabel or str): The function.
+            func (callable or str): The function.
             exclude (list[int or str]): A list of parameters to exclude, either
                 given by their name or index. (default: :obj:`None`)
         """
@@ -448,6 +448,10 @@ def type_repr(obj: Any, _globals: Dict[str, Any]) -> str:
         return '...'
 
     if obj.__module__ == 'typing':  # Special logic for `typing.*` types:
+
+        if not hasattr(obj, '_name'):
+            return repr(obj)
+
         name = obj._name
         if name is None:  # In some cases, `_name` is not populated.
             name = str(obj.__origin__).split('.')[-1]

torch_geometric/io/fs.py CHANGED
@@ -1,6 +1,9 @@
 import io
 import os.path as osp
+import pickle
+import re
 import sys
+import warnings
 from typing import Any, Dict, List, Literal, Optional, Union, overload
 from uuid import uuid4
 
@@ -186,11 +189,11 @@ def rm(path: str, recursive: bool = True) -> None:
     get_fs(path).rm(path, recursive)
 
 
-def mv(path1: str, path2: str, recursive: bool = True) -> None:
+def mv(path1: str, path2: str) -> None:
     fs1 = get_fs(path1)
     fs2 = get_fs(path2)
     assert fs1.protocol == fs2.protocol
-    fs1.mv(path1, path2, recursive)
+    fs1.mv(path1, path2)
 
 
 def glob(path: str) -> List[str]:
@@ -211,5 +214,28 @@ def torch_save(data: Any, path: str) -> None:
 
 
 def torch_load(path: str, map_location: Any = None) -> Any:
+    if torch_geometric.typing.WITH_PT24:
+        try:
+            with fsspec.open(path, 'rb') as f:
+                return torch.load(f, map_location, weights_only=True)
+        except pickle.UnpicklingError as e:
+            error_msg = str(e)
+            if "add_safe_globals" in error_msg:
+                warn_msg = ("Weights only load failed. Please file an issue "
+                            "to make `torch.load(weights_only=True)` "
+                            "compatible in your case.")
+                match = re.search(r'add_safe_globals\(.*?\)', error_msg)
+                if match is not None:
+                    warnings.warn(f"{warn_msg} Please use "
+                                  f"`torch.serialization.{match.group()}` to "
+                                  f"allowlist this global.")
+                else:
+                    warnings.warn(warn_msg)
+
+                with fsspec.open(path, 'rb') as f:
+                    return torch.load(f, map_location, weights_only=False)
+            else:
+                raise e
+
     with fsspec.open(path, 'rb') as f:
         return torch.load(f, map_location)
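The new `torch_load` path first tries `weights_only=True` (on PyTorch >= 2.4) and only falls back after warning about globals that need allowlisting. A hedged sketch of how a caller would follow that warning, where `MyCustomClass` is a hypothetical class named in the error message:

import torch
from torch_geometric.io import fs
from my_package import MyCustomClass  # hypothetical: the global flagged in the warning

# Allowlist the class so `torch.load(..., weights_only=True)` accepts it (PyTorch >= 2.4).
torch.serialization.add_safe_globals([MyCustomClass])
obj = fs.torch_load('checkpoint.pt')
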
torch_geometric/io/npz.py CHANGED
@@ -1,7 +1,6 @@
 from typing import Any, Dict
 
 import numpy as np
-import scipy.sparse as sp
 import torch
 
 from torch_geometric.data import Data
@@ -15,6 +14,8 @@ def read_npz(path: str, to_undirected: bool = True) -> Data:
 
 
 def parse_npz(f: Dict[str, Any], to_undirected: bool = True) -> Data:
+    import scipy.sparse as sp
+
     x = sp.csr_matrix((f['attr_data'], f['attr_indices'], f['attr_indptr']),
                       f['attr_shape']).todense()
     x = torch.from_numpy(x).to(torch.float)
torch_geometric/io/off.py CHANGED
@@ -16,7 +16,7 @@ def parse_off(src: List[str]) -> Data:
     else:
         src[0] = src[0][3:]
 
-    num_nodes, num_faces = [int(item) for item in src[0].split()[:2]]
+    num_nodes, num_faces = (int(item) for item in src[0].split()[:2])
 
     pos = parse_txt_array(src[1:1 + num_nodes])
 
@@ -52,7 +52,7 @@ def read_off(path: str) -> Data:
     Args:
         path (str): The path to the file.
     """
-    with open(path, 'r') as f:
+    with open(path) as f:
        src = f.read().split('\n')[:-1]
     return parse_off(src)
 
torch_geometric/io/sdf.py CHANGED
@@ -9,7 +9,7 @@ elems = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
 
 def parse_sdf(src: str) -> Data:
     lines = src.split('\n')[3:]
-    num_atoms, num_bonds = [int(item) for item in lines[0].split()[:2]]
+    num_atoms, num_bonds = (int(item) for item in lines[0].split()[:2])
 
     atom_block = lines[1:num_atoms + 1]
     pos = parse_txt_array(atom_block, end=3)
@@ -28,5 +28,5 @@ def parse_sdf(src: str) -> Data:
 
 
 def read_sdf(path: str) -> Data:
-    with open(path, 'r') as f:
+    with open(path) as f:
         return parse_sdf(f.read())
torch_geometric/io/tu.py CHANGED
@@ -1,7 +1,6 @@
 import os.path as osp
 from typing import Dict, List, Optional, Tuple
 
-import numpy as np
 import torch
 from torch import Tensor
 
@@ -108,11 +107,11 @@ def cat(seq: List[Optional[Tensor]]) -> Optional[Tensor]:
 
 
 def split(data: Data, batch: Tensor) -> Tuple[Data, Dict[str, Tensor]]:
-    node_slice = cumsum(torch.from_numpy(np.bincount(batch)))
+    node_slice = cumsum(torch.bincount(batch))
 
     assert data.edge_index is not None
     row, _ = data.edge_index
-    edge_slice = cumsum(torch.from_numpy(np.bincount(batch[row])))
+    edge_slice = cumsum(torch.bincount(batch[row]))
 
     # Edge indices should start at zero for every graph.
     data.edge_index -= node_slice[batch[row]].unsqueeze(0)
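For the `split` change above, `torch.bincount` replaces the NumPy round-trip. A small self-contained sketch of what the bincount-based slicing computes (hypothetical batch vector; plain `torch.cumsum` stands in for PyG's `cumsum` helper):

import torch

batch = torch.tensor([0, 0, 0, 1, 1])         # 5 nodes assigned to graphs 0 and 1
counts = torch.bincount(batch)                # tensor([3, 2]): nodes per graph
node_slice = torch.cat([counts.new_zeros(1), counts.cumsum(0)])
print(node_slice)                             # tensor([0, 3, 5]): per-graph node boundaries
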
torch_geometric/loader/__init__.py CHANGED
@@ -22,6 +22,7 @@ from .dynamic_batch_sampler import DynamicBatchSampler
 from .prefetch import PrefetchLoader
 from .cache import CachedLoader
 from .mixin import AffinityMixin
+from .rag_loader import RAGQueryLoader, RAGFeatureStore, RAGGraphStore
 
 __all__ = classes = [
     'DataLoader',
@@ -50,6 +51,9 @@ __all__ = classes = [
     'PrefetchLoader',
     'CachedLoader',
     'AffinityMixin',
+    'RAGQueryLoader',
+    'RAGFeatureStore',
+    'RAGGraphStore'
 ]
 
 RandomNodeSampler = deprecated(