pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic. (More details are linked from the original page.)

Files changed (268)
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
@@ -19,17 +19,15 @@ class Actor(InMemoryDataset):
19
19
  actor's Wikipedia.
20
20
 
21
21
  Args:
22
- root (str): Root directory where the dataset should be saved.
23
- transform (callable, optional): A function/transform that takes in an
22
+ root: Root directory where the dataset should be saved.
23
+ transform: A function/transform that takes in an
24
24
  :obj:`torch_geometric.data.Data` object and returns a transformed
25
25
  version. The data object will be transformed before every access.
26
- (default: :obj:`None`)
27
- pre_transform (callable, optional): A function/transform that takes in
28
- an :obj:`torch_geometric.data.Data` object and returns a
29
- transformed version. The data object will be transformed before
30
- being saved to disk. (default: :obj:`None`)
31
- force_reload (bool, optional): Whether to re-process the dataset.
32
- (default: :obj:`False`)
26
+ pre_transform: A function/transform that takes in an
27
+ :class:`torch_geometric.data.Data` object and returns a transformed
28
+ version. The data object will be transformed before being saved to
29
+ disk.
30
+ force_reload: Whether to re-process the dataset.
33
31
 
34
32
  **STATS:**
35
33
 
@@ -2,14 +2,13 @@ import json
2
2
  import os
3
3
  from typing import Callable, List, Optional
4
4
 
5
- import torch
6
-
7
5
  from torch_geometric.data import (
8
6
  Data,
9
7
  InMemoryDataset,
10
8
  download_url,
11
9
  extract_zip,
12
10
  )
11
+ from torch_geometric.io import fs
13
12
 
14
13
 
15
14
  class AirfRANS(InMemoryDataset):
@@ -26,13 +25,13 @@ class AirfRANS(InMemoryDataset):
26
25
  features: the inlet velocity (two components in meter per second), the
27
26
  distance to the airfoil (one component in meter), and the normals (two
28
27
  components in meter, set to :obj:`0` if the point is not on the airfoil).
29
- Each point is given a target of 4 components for the underyling regression
28
+ Each point is given a target of 4 components for the underlying regression
30
29
  task: the velocity (two components in meter per second), the pressure
31
30
  divided by the specific mass (one component in meter squared per second
32
31
  squared), the turbulent kinematic viscosity (one component in meter squared
33
32
  per second).
34
- Finaly, a boolean is attached to each point to inform if this point lies on
35
- the airfoil or not.
33
+ Finally, a boolean is attached to each point to inform if this point lies
34
+ on the airfoil or not.
36
35
 
37
36
  A library for manipulating simulations of the dataset is available `here
38
37
  <https://airfrans.readthedocs.io/en/latest/index.html>`_.
@@ -47,26 +46,24 @@ class AirfRANS(InMemoryDataset):
47
46
  :obj:`torch_geometric.transforms.RadiusGraph` transform.
48
47
 
49
48
  Args:
50
- root (str): Root directory where the dataset should be saved.
51
- task (str): The task to study (:obj:`"full"`, :obj:`"scarce"`,
49
+ root: Root directory where the dataset should be saved.
50
+ task: The task to study (:obj:`"full"`, :obj:`"scarce"`,
52
51
  :obj:`"reynolds"`, :obj:`"aoa"`) that defines the utilized training
53
52
  and test splits.
54
- train (bool, optional): If :obj:`True`, loads the training dataset,
55
- otherwise the test dataset. (default: :obj:`True`)
56
- transform (callable, optional): A function/transform that takes in an
57
- :obj:`torch_geometric.data.Data` object and returns a transformed
53
+ train: If :obj:`True`, loads the training dataset, otherwise the test
54
+ dataset.
55
+ transform: A function/transform that takes in an
56
+ :class:`torch_geometric.data.Data` object and returns a transformed
58
57
  version. The data object will be transformed before every access.
59
- (default: :obj:`None`)
60
- pre_transform (callable, optional): A function/transform that takes in
61
- an :obj:`torch_geometric.data.Data` object and returns a
58
+ pre_transform: A function/transform that takes in an
59
+ :class:`torch_geometric.data.Data` object and returns a
62
60
  transformed version. The data object will be transformed before
63
- being saved to disk. (default: :obj:`None`)
64
- pre_filter (callable, optional): A function that takes in an
61
+ being saved to disk.
62
+ pre_filter: A function that takes in an
65
63
  :obj:`torch_geometric.data.Data` object and returns a boolean
66
64
  value, indicating whether the data object should be included in the
67
- final dataset. (default: :obj:`None`)
68
- force_reload (bool, optional): Whether to re-process the dataset.
69
- (default: :obj:`False`)
65
+ final dataset.
66
+ force_reload: Whether to re-process the dataset.
70
67
 
71
68
  **STATS:**
72
69
 
@@ -129,7 +126,7 @@ class AirfRANS(InMemoryDataset):
129
126
  partial = set(manifest[f'{self.task}_{self.split}'])
130
127
 
131
128
  data_list = []
132
- raw_data = torch.load(self.raw_paths[0])
129
+ raw_data = fs.torch_load(self.raw_paths[0])
133
130
  for k, s in enumerate(total):
134
131
  if s in partial:
135
132
  data = Data(**raw_data[k])
@@ -14,22 +14,20 @@ class Airports(InMemoryDataset):
14
14
  and labels correspond to activity levels.
15
15
  Features are given by one-hot encoded node identifiers, as described in the
16
16
  `"GraLSP: Graph Neural Networks with Local Structural Patterns"
17
- ` <https://arxiv.org/abs/1911.07675>`_ paper.
17
+ <https://arxiv.org/abs/1911.07675>`_ paper.
18
18
 
19
19
  Args:
20
- root (str): Root directory where the dataset should be saved.
21
- name (str): The name of the dataset (:obj:`"USA"`, :obj:`"Brazil"`,
20
+ root: Root directory where the dataset should be saved.
21
+ name: The name of the dataset (:obj:`"USA"`, :obj:`"Brazil"`,
22
22
  :obj:`"Europe"`).
23
- transform (callable, optional): A function/transform that takes in an
24
- :obj:`torch_geometric.data.Data` object and returns a transformed
23
+ transform: A function/transform that takes in an
24
+ :class:`torch_geometric.data.Data` object and returns a transformed
25
25
  version. The data object will be transformed before every access.
26
- (default: :obj:`None`)
27
26
  pre_transform (callable, optional): A function/transform that takes in
28
- an :obj:`torch_geometric.data.Data` object and returns a
27
+ :class:`torch_geometric.data.Data` object and returns a
29
28
  transformed version. The data object will be transformed before
30
- being saved to disk. (default: :obj:`None`)
31
- force_reload (bool, optional): Whether to re-process the dataset.
32
- (default: :obj:`False`)
29
+ being saved to disk.
30
+ force_reload: Whether to re-process the dataset.
33
31
  """
34
32
  edge_url = ('https://github.com/leoribeiro/struc2vec/'
35
33
  'raw/master/graph/{}-airports.edgelist')
@@ -15,19 +15,16 @@ class Amazon(InMemoryDataset):
15
15
  map goods to their respective product category.
16
16
 
17
17
  Args:
18
- root (str): Root directory where the dataset should be saved.
19
- name (str): The name of the dataset (:obj:`"Computers"`,
20
- :obj:`"Photo"`).
21
- transform (callable, optional): A function/transform that takes in an
22
- :obj:`torch_geometric.data.Data` object and returns a transformed
18
+ root: Root directory where the dataset should be saved.
19
+ name: The name of the dataset (:obj:`"Computers"`, :obj:`"Photo"`).
20
+ transform: A function/transform that takes in a
21
+ :class:`torch_geometric.data.Data` object and returns a transformed
23
22
  version. The data object will be transformed before every access.
24
- (default: :obj:`None`)
25
- pre_transform (callable, optional): A function/transform that takes in
26
- an :obj:`torch_geometric.data.Data` object and returns a
23
+ pre_transform: A function/transform that takes in an
24
+ :class:`torch_geometric.data.Data` object and returns a
27
25
  transformed version. The data object will be transformed before
28
- being saved to disk. (default: :obj:`None`)
29
- force_reload (bool, optional): Whether to re-process the dataset.
30
- (default: :obj:`False`)
26
+ being saved to disk.
27
+ force_reload: Whether to re-process the dataset.
31
28
 
32
29
  **STATS:**
33
30
 
@@ -14,17 +14,16 @@ class AmazonBook(InMemoryDataset):
14
14
  No labels or features are provided.
15
15
 
16
16
  Args:
17
- root (str): Root directory where the dataset should be saved.
18
- transform (callable, optional): A function/transform that takes in an
19
- :obj:`torch_geometric.data.HeteroData` object and returns a
17
+ root: Root directory where the dataset should be saved.
18
+ transform: A function/transform that takes in an
19
+ :class:`torch_geometric.data.HeteroData` object and returns a
20
20
  transformed version. The data object will be transformed before
21
- every access. (default: :obj:`None`)
22
- pre_transform (callable, optional): A function/transform that takes in
23
- an :obj:`torch_geometric.data.HeteroData` object and returns a
21
+ every access.
22
+ pre_transform: A function/transform that takes in an
23
+ :class:`torch_geometric.data.HeteroData` object and returns a
24
24
  transformed version. The data object will be transformed before
25
- being saved to disk. (default: :obj:`None`)
26
- force_reload (bool, optional): Whether to re-process the dataset.
27
- (default: :obj:`False`)
25
+ being saved to disk.
26
+ force_reload: Whether to re-process the dataset.
28
27
  """
29
28
  url = ('https://raw.githubusercontent.com/gusye1234/LightGCN-PyTorch/'
30
29
  'master/data/amazon-book')
@@ -14,17 +14,15 @@ class AmazonProducts(InMemoryDataset):
14
14
  containing products and its categories.
15
15
 
16
16
  Args:
17
- root (str): Root directory where the dataset should be saved.
18
- transform (callable, optional): A function/transform that takes in an
19
- :obj:`torch_geometric.data.Data` object and returns a transformed
17
+ root: Root directory where the dataset should be saved.
18
+ transform: A function/transform that takes in an
19
+ :class:`torch_geometric.data.Data` object and returns a transformed
20
20
  version. The data object will be transformed before every access.
21
- (default: :obj:`None`)
22
- pre_transform (callable, optional): A function/transform that takes in
23
- an :obj:`torch_geometric.data.Data` object and returns a
21
+ pre_transform: A function/transform that takes in a
22
+ :class:`torch_geometric.data.Data` object and returns a
24
23
  transformed version. The data object will be transformed before
25
- being saved to disk. (default: :obj:`None`)
26
- force_reload (bool, optional): Whether to re-process the dataset.
27
- (default: :obj:`False`)
24
+ being saved to disk.
25
+ force_reload: Whether to re-process the dataset.
28
26
 
29
27
  **STATS:**
30
28
 
@@ -24,17 +24,16 @@ class AMiner(InMemoryDataset):
24
24
  truth labels for a subset of nodes.
25
25
 
26
26
  Args:
27
- root (str): Root directory where the dataset should be saved.
28
- transform (callable, optional): A function/transform that takes in an
29
- :obj:`torch_geometric.data.HeteroData` object and returns a
27
+ root: Root directory where the dataset should be saved.
28
+ transform: A function/transform that takes in a
29
+ :class:`torch_geometric.data.HeteroData` object and returns a
30
30
  transformed version. The data object will be transformed before
31
- every access. (default: :obj:`None`)
32
- pre_transform (callable, optional): A function/transform that takes in
33
- an :obj:`torch_geometric.data.HeteroData` object and returns a
31
+ every access.
32
+ pre_transform: A function/transform that takes in a
33
+ :class:`torch_geometric.data.HeteroData` object and returns a
34
34
  transformed version. The data object will be transformed before
35
- being saved to disk. (default: :obj:`None`)
36
- force_reload (bool, optional): Whether to re-process the dataset.
37
- (default: :obj:`False`)
35
+ being saved to disk.
36
+ force_reload: Whether to re-process the dataset.
38
37
  """
39
38
 
40
39
  url = 'https://www.dropbox.com/s/1bnz8r7mofx0osf/net_aminer.zip?dl=1'
@@ -30,25 +30,22 @@ class AQSOL(InMemoryDataset):
30
30
  the :class:`~torch_geometric.datasets.ZINC` dataset.
31
31
 
32
32
  Args:
33
- root (str): Root directory where the dataset should be saved.
34
- split (str, optional): If :obj:`"train"`, loads the training dataset.
33
+ root: Root directory where the dataset should be saved.
34
+ split: If :obj:`"train"`, loads the training dataset.
35
35
  If :obj:`"val"`, loads the validation dataset.
36
36
  If :obj:`"test"`, loads the test dataset.
37
- (default: :obj:`"train"`)
38
- transform (callable, optional): A function/transform that takes in an
39
- :obj:`torch_geometric.data.Data` object and returns a transformed
37
+ transform: A function/transform that takes in a
38
+ :class:`torch_geometric.data.Data` object and returns a transformed
40
39
  version. The data object will be transformed before every access.
41
- (default: :obj:`None`)
42
- pre_transform (callable, optional): A function/transform that takes in
43
- an :obj:`torch_geometric.data.Data` object and returns a
40
+ pre_transform: A function/transform that takes in a
41
+ :class:`torch_geometric.data.Data` object and returns a
44
42
  transformed version. The data object will be transformed before
45
- being saved to disk. (default: :obj:`None`)
43
+ being saved to disk.
46
44
  pre_filter (callable, optional): A function that takes in an
47
- :obj:`torch_geometric.data.Data` object and returns a boolean
45
+ :class:`torch_geometric.data.Data` object and returns a boolean
48
46
  value, indicating whether the data object should be included in
49
- the final dataset. (default: :obj:`None`)
50
- force_reload (bool, optional): Whether to re-process the dataset.
51
- (default: :obj:`False`)
47
+ the final dataset.
48
+ force_reload: Whether to re-process the dataset.
52
49
 
53
50
  **STATS:**
54
51
 
@@ -19,21 +19,19 @@ class AttributedGraphDataset(InMemoryDataset):
19
19
  <https://arxiv.org/abs/2009.00826>`_ paper.
20
20
 
21
21
  Args:
22
- root (str): Root directory where the dataset should be saved.
23
- name (str): The name of the dataset (:obj:`"Wiki"`, :obj:`"Cora"`
22
+ root: Root directory where the dataset should be saved.
23
+ name: The name of the dataset (:obj:`"Wiki"`, :obj:`"Cora"`,
24
24
  :obj:`"CiteSeer"`, :obj:`"PubMed"`, :obj:`"BlogCatalog"`,
25
25
  :obj:`"PPI"`, :obj:`"Flickr"`, :obj:`"Facebook"`, :obj:`"Twitter"`,
26
26
  :obj:`"TWeibo"`, :obj:`"MAG"`).
27
- transform (callable, optional): A function/transform that takes in an
28
- :obj:`torch_geometric.data.Data` object and returns a transformed
27
+ transform: A function/transform that takes in a
28
+ :class:`torch_geometric.data.Data` object and returns a transformed
29
29
  version. The data object will be transformed before every access.
30
- (default: :obj:`None`)
31
- pre_transform (callable, optional): A function/transform that takes in
32
- an :obj:`torch_geometric.data.Data` object and returns a
30
+ pre_transform: A function/transform that takes in a
31
+ :class:`torch_geometric.data.Data` object and returns a
33
32
  transformed version. The data object will be transformed before
34
- being saved to disk. (default: :obj:`None`)
35
- force_reload (bool, optional): Whether to re-process the dataset.
36
- (default: :obj:`False`)
33
+ being saved to disk.
34
+ force_reload: Whether to re-process the dataset.
37
35
 
38
36
  **STATS:**
39
37
 
@@ -25,21 +25,19 @@ class BAMultiShapesDataset(InMemoryDataset):
25
25
  This dataset is pre-computed from the official implementation.
26
26
 
27
27
  Args:
28
- root (str): Root directory where the dataset should be saved.
29
- transform (callable, optional): A function/transform that takes in an
30
- :obj:`torch_geometric.data.Data` object and returns a transformed
28
+ root: Root directory where the dataset should be saved.
29
+ transform: A function/transform that takes in a
30
+ :class:`torch_geometric.data.Data` object and returns a transformed
31
31
  version. The data object will be transformed before every access.
32
- (default: :obj:`None`)
33
- pre_transform (callable, optional): A function/transform that takes in
34
- an :obj:`torch_geometric.data.Data` object and returns a
32
+ pre_transform: A function/transform that takes in a
33
+ :class:`torch_geometric.data.Data` object and returns a
35
34
  transformed version. The data object will be transformed before
36
- being saved to disk. (default: :obj:`None`)
37
- pre_filter (callable, optional): A function that takes in an
38
- :obj:`torch_geometric.data.Data` object and returns a boolean
35
+ being saved to disk.
36
+ pre_filter: A function that takes in a
37
+ :class:`torch_geometric.data.Data` object and returns a boolean
39
38
  value, indicating whether the data object should be included in the
40
- final dataset. (default: :obj:`None`)
41
- force_reload (bool, optional): Whether to re-process the dataset.
42
- (default: :obj:`False`)
39
+ final dataset.
40
+ force_reload: Whether to re-process the dataset.
43
41
 
44
42
  **STATS:**
45
43
 
@@ -30,15 +30,14 @@ class BAShapes(InMemoryDataset):
30
30
  :class:`torch_geometric.datasets.graph_generator.BAGraph` instead.
31
31
 
32
32
  Args:
33
- connection_distribution (str, optional): Specifies how the houses
34
- and the BA graph get connected. Valid inputs are :obj:`"random"`
33
+ connection_distribution: Specifies how the houses and the BA graph get
34
+ connected. Valid inputs are :obj:`"random"`
35
35
  (random BA graph nodes are selected for connection to the houses),
36
36
  and :obj:`"uniform"` (uniformly distributed BA graph nodes are
37
- selected for connection to the houses). (default: :obj:`"random"`)
38
- transform (callable, optional): A function/transform that takes in an
39
- :obj:`torch_geometric.data.Data` object and returns a transformed
37
+ selected for connection to the houses).
38
+ transform: A function/transform that takes in a
39
+ :class:`torch_geometric.data.Data` object and returns a transformed
40
40
  version. The data object will be transformed before every access.
41
- (default: :obj:`None`)
42
41
  """
43
42
  def __init__(
44
43
  self,
@@ -94,7 +94,7 @@ class BrcaTcga(InMemoryDataset):
94
94
  graph_feat = torch.from_numpy(graph_feat).to(torch.float)
95
95
  graph_labels = np.loadtxt(self.raw_paths[1], delimiter=',')
96
96
  graph_label = torch.from_numpy(graph_labels).to(torch.float)
97
- edge_index = torch.load(self.raw_paths[2])
97
+ edge_index = fs.torch_load(self.raw_paths[2])
98
98
 
99
99
  data_list = []
100
100
  for x, y in zip(graph_feat, graph_label):
@@ -0,0 +1,157 @@
1
+ import os.path as osp
2
+ from typing import Callable, Optional
3
+
4
+ from torch_geometric.data import (
5
+ Data,
6
+ InMemoryDataset,
7
+ download_url,
8
+ extract_tar,
9
+ )
10
+ from torch_geometric.io import fs
11
+
12
+
13
class CityNetwork(InMemoryDataset):
    r"""The City-Networks dataset introduced in the
    `"Towards Quantifying Long-Range Interactions in Graph Machine Learning:
    a Large Graph Dataset and a Measurement"
    <https://arxiv.org/abs/2503.09008>`_ paper.
    The dataset contains four city networks: `paris`, `shanghai`, `la`,
    and `london`, where nodes represent junctions and edges represent
    undirected road segments. The task is to predict each node's eccentricity
    score, which is approximated based on its 16-hop neighborhood and naturally
    requires long-range information. The score indicates how accessible one
    node is in the network, and is mapped to 10 quantiles for transductive
    classification. See the original
    `source code <https://github.com/LeonResearch/City-Networks>`_ for more
    details on the individual networks.

    Args:
        root (str): Root directory where the dataset should be saved.
        name (str): The name of the dataset (``"paris"``, ``"shanghai"``,
            ``"la"``, ``"london"``). Matching is case-insensitive.
        augmented (bool, optional): Whether to use the augmented node features
            from edge features. (default: :obj:`True`)
        transform (callable, optional): A function/transform that takes in a
            :class:`~torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            every access. (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            a :class:`~torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
        delete_raw (bool, optional): Whether to remove the extracted raw
            directory (``<root>/<name>/raw/<name>``) once processing has
            finished. The downloaded archive itself is not removed.
            (default: :obj:`False`)

    **STATS:**

    .. list-table::
        :widths: 10 10 10 10 10
        :header-rows: 1

        * - Name
          - #nodes
          - #edges
          - #features
          - #classes
        * - paris
          - 114,127
          - 182,511
          - 37
          - 10
        * - shanghai
          - 183,917
          - 262,092
          - 37
          - 10
        * - la
          - 240,587
          - 341,523
          - 37
          - 10
        * - london
          - 568,795
          - 756,502
          - 37
          - 10
    """
    url = "https://github.com/LeonResearch/City-Networks/raw/refs/heads/main/data/"  # noqa: E501

    def __init__(
        self,
        root: str,
        name: str,
        augmented: bool = True,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
        delete_raw: bool = False,
    ) -> None:
        # These attributes must exist before `super().__init__()` runs,
        # because the directory/file-name properties below read them.
        self.name = name.lower()
        assert self.name in ["paris", "shanghai", "la", "london"], (
            f"Unknown city network '{name}' (expected one of 'paris', "
            f"'shanghai', 'la', 'london')")
        self.augmented = augmented
        self.delete_raw = delete_raw
        super().__init__(
            root,
            transform,
            pre_transform,
            force_reload=force_reload,
        )
        self.load(self.processed_paths[0])

    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, self.name, "raw")

    @property
    def processed_dir(self) -> str:
        return osp.join(self.root, self.name, "processed")

    @property
    def raw_file_names(self) -> str:
        # NOTE(review): `download()` fetches `<name>.tar.gz`; whether the
        # archive provides `<name>.json` inside `raw_dir` cannot be confirmed
        # from this file alone -- verify against the published archive.
        return f"{self.name}.json"

    @property
    def processed_file_names(self) -> str:
        return "data.pt"

    @property
    def _archive_path(self) -> str:
        # Location of the downloaded tarball inside `raw_dir`; the file name
        # mirrors the last URL component requested in `download()`.
        return osp.join(self.raw_dir, f"{self.name}.tar.gz")

    def download(self) -> None:
        # `download_path` is still recorded for backward compatibility with
        # code that reads it after a download took place.
        self.download_path = download_url(
            self.url + f"{self.name}.tar.gz",
            self.raw_dir,
        )

    def process(self) -> None:
        # `download()` is not guaranteed to have run in this session (it is
        # skipped when the raw files are already present), so do not rely on
        # `self.download_path` being set; fall back to the archive location
        # derived from `raw_dir` instead of raising `AttributeError`.
        archive = getattr(self, "download_path", None)
        if archive is None:
            archive = self._archive_path
        extract_tar(archive, self.raw_dir)

        data_path = osp.join(self.raw_dir, self.name)
        feat_suffix = "_augmented" if self.augmented else ""
        node_feat = fs.torch_load(
            osp.join(data_path, f"node_features{feat_suffix}.pt"))
        edge_index = fs.torch_load(osp.join(data_path, "edge_indices.pt"))
        label = fs.torch_load(
            osp.join(data_path, "10-chunk_16-hop_node_labels.pt"))
        train_mask = fs.torch_load(osp.join(data_path, "train_mask.pt"))
        val_mask = fs.torch_load(osp.join(data_path, "valid_mask.pt"))
        test_mask = fs.torch_load(osp.join(data_path, "test_mask.pt"))

        data = Data(
            x=node_feat,
            edge_index=edge_index,
            y=label,
            train_mask=train_mask,
            val_mask=val_mask,
            test_mask=test_mask,
        )
        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save([data], self.processed_paths[0])

        # Optionally drop the extracted files; the tarball is left in place.
        if self.delete_raw:
            fs.rm(data_path)

    def __repr__(self) -> str:
        return (f"{self.__class__.__name__}("
                f"root='{self.root}', "
                f"name='{self.name}', "
                f"augmented={self.augmented})")
@@ -73,7 +73,7 @@ class DBP15K(InMemoryDataset):
73
73
  def process(self) -> None:
74
74
  embs = {}
75
75
  with open(osp.join(self.raw_dir, 'sub.glove.300d')) as f:
76
- for i, line in enumerate(f):
76
+ for line in f:
77
77
  info = line.strip().split(' ')
78
78
  if len(info) > 300:
79
79
  embs[info[0]] = torch.tensor([float(x) for x in info[1:]])
@@ -9,6 +9,7 @@ from torch_geometric.data import (
9
9
  download_url,
10
10
  extract_zip,
11
11
  )
12
+ from torch_geometric.io import fs
12
13
 
13
14
 
14
15
  class GDELTLite(InMemoryDataset):
@@ -80,9 +81,9 @@ class GDELTLite(InMemoryDataset):
80
81
  def process(self) -> None:
81
82
  import pandas as pd
82
83
 
83
- x = torch.load(self.raw_paths[0])
84
+ x = fs.torch_load(self.raw_paths[0])
84
85
  df = pd.read_csv(self.raw_paths[1])
85
- edge_attr = torch.load(self.raw_paths[2])
86
+ edge_attr = fs.torch_load(self.raw_paths[2])
86
87
 
87
88
  row = torch.from_numpy(df['src'].values)
88
89
  col = torch.from_numpy(df['dst'].values)
@@ -13,6 +13,7 @@ from torch_geometric.data import (
13
13
  extract_tar,
14
14
  extract_zip,
15
15
  )
16
+ from torch_geometric.io import fs
16
17
  from torch_geometric.utils import one_hot, to_undirected
17
18
 
18
19
 
@@ -145,9 +146,9 @@ class GEDDataset(InMemoryDataset):
145
146
  path = self.processed_paths[0] if train else self.processed_paths[1]
146
147
  self.load(path)
147
148
  path = osp.join(self.processed_dir, f'{self.name}_ged.pt')
148
- self.ged = torch.load(path)
149
+ self.ged = fs.torch_load(path)
149
150
  path = osp.join(self.processed_dir, f'{self.name}_norm_ged.pt')
150
- self.norm_ged = torch.load(path)
151
+ self.norm_ged = fs.torch_load(path)
151
152
 
152
153
  @property
153
154
  def raw_file_names(self) -> List[str]: