pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226)
  1. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +8 -3
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +159 -34
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +2 -4
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +322 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +53 -20
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
@@ -13,7 +13,7 @@ from torch_geometric.typing import InputEdges, InputNodes, OptTensor
13
13
  try:
14
14
  from pytorch_lightning import LightningDataModule as PLLightningDataModule
15
15
  no_pytorch_lightning = False
16
- except (ImportError, ModuleNotFoundError):
16
+ except ImportError:
17
17
  PLLightningDataModule = object # type: ignore
18
18
  no_pytorch_lightning = True
19
19
 
@@ -221,7 +221,7 @@ class LightningDataset(LightningDataModule):
221
221
  speed.html>`__ are supported in order to correctly share data across
222
222
  all devices/processes:
223
223
 
224
- .. code-block::
224
+ .. code-block:: python
225
225
 
226
226
  import pytorch_lightning as pl
227
227
  trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
@@ -332,7 +332,7 @@ class LightningNodeData(LightningData):
332
332
  speed.html>`__ are supported in order to correctly share data across
333
333
  all devices/processes:
334
334
 
335
- .. code-block::
335
+ .. code-block:: python
336
336
 
337
337
  import pytorch_lightning as pl
338
338
  trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
@@ -525,7 +525,7 @@ class LightningLinkData(LightningData):
525
525
  speed.html>`__ are supported in order to correctly share data across
526
526
  all devices/processes:
527
527
 
528
- .. code-block::
528
+ .. code-block:: python
529
529
 
530
530
  import pytorch_lightning as pl
531
531
  trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
@@ -3,7 +3,7 @@ from typing import Any, Type, TypeVar
3
3
 
4
4
  from torch import Tensor
5
5
 
6
- from torch_geometric import EdgeIndex
6
+ from torch_geometric import EdgeIndex, Index
7
7
  from torch_geometric.data.data import BaseData
8
8
  from torch_geometric.data.storage import BaseStorage
9
9
  from torch_geometric.typing import SparseTensor, TensorFrame
@@ -76,6 +76,11 @@ def _separate(
76
76
  value = narrow(values, cat_dim or 0, start, end - start)
77
77
  value = value.squeeze(0) if cat_dim is None else value
78
78
 
79
+ if isinstance(values, Index) and values._cat_metadata is not None:
80
+ # Reconstruct original `Index` metadata:
81
+ value._dim_size = values._cat_metadata.dim_size[idx]
82
+ value._is_sorted = values._cat_metadata.is_sorted[idx]
83
+
79
84
  if isinstance(values, EdgeIndex) and values._cat_metadata is not None:
80
85
  # Reconstruct original `EdgeIndex` metadata:
81
86
  value._sparse_size = values._cat_metadata.sparse_size[idx]
@@ -370,18 +370,20 @@ class BaseStorage(MutableMapping):
370
370
  self,
371
371
  start_time: Union[float, int],
372
372
  end_time: Union[float, int],
373
+ attr: str = 'time',
373
374
  ) -> Self:
374
- if 'time' in self:
375
- mask = (self.time >= start_time) & (self.time <= end_time)
375
+ if attr in self:
376
+ time = self[attr]
377
+ mask = (time >= start_time) & (time <= end_time)
376
378
 
377
- if self.is_node_attr('time'):
379
+ if self.is_node_attr(attr):
378
380
  keys = self.node_attrs()
379
- elif self.is_edge_attr('time'):
381
+ elif self.is_edge_attr(attr):
380
382
  keys = self.edge_attrs()
381
383
 
382
384
  self._select(keys, mask)
383
385
 
384
- if self.is_node_attr('time') and 'num_nodes' in self:
386
+ if self.is_node_attr(attr) and 'num_nodes' in self:
385
387
  self.num_nodes: Optional[int] = int(mask.sum())
386
388
 
387
389
  return self
@@ -443,9 +445,9 @@ class NodeStorage(BaseStorage):
443
445
  return self.edge_index.sparse_size(0)
444
446
  if self.edge_index.sparse_size(1) is not None:
445
447
  return self.edge_index.sparse_size(1)
446
- if 'adj' in self and isinstance(self.adj, SparseTensor):
448
+ if 'adj' in self and isinstance(self.adj, (Tensor, SparseTensor)):
447
449
  return self.adj.size(0)
448
- if 'adj_t' in self and isinstance(self.adj_t, SparseTensor):
450
+ if 'adj_t' in self and isinstance(self.adj_t, (Tensor, SparseTensor)):
449
451
  return self.adj_t.size(1)
450
452
  warnings.warn(
451
453
  f"Unable to accurately infer 'num_nodes' from the attribute set "
@@ -804,6 +806,10 @@ class GlobalStorage(NodeStorage, EdgeStorage):
804
806
  return False
805
807
 
806
808
  cat_dim = self._parent().__cat_dim__(key, value, self)
809
+
810
+ if not isinstance(cat_dim, int):
811
+ return False
812
+
807
813
  num_nodes, num_edges = self.num_nodes, self.num_edges
808
814
 
809
815
  if value.shape[cat_dim] != num_nodes:
@@ -850,6 +856,10 @@ class GlobalStorage(NodeStorage, EdgeStorage):
850
856
  return False
851
857
 
852
858
  cat_dim = self._parent().__cat_dim__(key, value, self)
859
+
860
+ if not isinstance(cat_dim, int):
861
+ return False
862
+
853
863
  num_nodes, num_edges = self.num_nodes, self.num_edges
854
864
 
855
865
  if value.shape[cat_dim] != num_edges:
@@ -117,7 +117,14 @@ class Summary:
117
117
  num_edges_per_type=num_edges_per_type,
118
118
  )
119
119
 
120
- def __repr__(self) -> str:
120
+ def format(self, fmt: str = "psql") -> str:
121
+ r"""Formats summary statistics of the dataset.
122
+
123
+ Args:
124
+ fmt (str, optional): Summary tables format. Available table formats
125
+ can be found `here <https://github.com/astanin/python-tabulate?
126
+ tab=readme-ov-file#table-format>`__. (default: :obj:`"psql"`)
127
+ """
121
128
  from tabulate import tabulate
122
129
 
123
130
  body = f'{self.name} (#graphs={self.num_graphs}):\n'
@@ -127,7 +134,7 @@ class Summary:
127
134
  for field in Stats.__dataclass_fields__:
128
135
  row = [field] + [f'{getattr(s, field):.1f}' for s in stats]
129
136
  content.append(row)
130
- body += tabulate(content, headers='firstrow', tablefmt='psql')
137
+ body += tabulate(content, headers='firstrow', tablefmt=fmt)
131
138
 
132
139
  if self.num_nodes_per_type is not None:
133
140
  content = [['']]
@@ -140,7 +147,7 @@ class Summary:
140
147
  ]
141
148
  content.append(row)
142
149
  body += "\nNumber of nodes per node type:\n"
143
- body += tabulate(content, headers='firstrow', tablefmt='psql')
150
+ body += tabulate(content, headers='firstrow', tablefmt=fmt)
144
151
 
145
152
  if self.num_edges_per_type is not None:
146
153
  content = [['']]
@@ -156,6 +163,9 @@ class Summary:
156
163
  ]
157
164
  content.append(row)
158
165
  body += "\nNumber of edges per edge type:\n"
159
- body += tabulate(content, headers='firstrow', tablefmt='psql')
166
+ body += tabulate(content, headers='firstrow', tablefmt=fmt)
160
167
 
161
168
  return body
169
+
170
+ def __repr__(self) -> str:
171
+ return self.format()
@@ -156,8 +156,7 @@ class TemporalData(BaseData):
156
156
  return self.num_events
157
157
 
158
158
  def __call__(self, *args: List[str]) -> Iterable:
159
- for key, value in self._store.items(*args):
160
- yield key, value
159
+ yield from self._store.items(*args)
161
160
 
162
161
  def __copy__(self):
163
162
  out = self.__class__.__new__(self.__class__)
@@ -61,7 +61,6 @@ from .gemsec import GemsecDeezer
61
61
  from .twitch import Twitch
62
62
  from .airports import Airports
63
63
  from .lrgb import LRGBDataset
64
- from .neurograph import NeuroGraphDataset
65
64
  from .malnet_tiny import MalNetTiny
66
65
  from .omdb import OMDB
67
66
  from .polblogs import PolBlogs
@@ -76,6 +75,11 @@ from .jodie import JODIEDataset
76
75
  from .wikidata import Wikidata5M
77
76
  from .myket import MyketDataset
78
77
  from .brca_tgca import BrcaTcga
78
+ from .neurograph import NeuroGraphDataset
79
+ from .web_qsp_dataset import WebQSPDataset
80
+ from .git_mol_dataset import GitMolDataset
81
+ from .molecule_gpt_dataset import MoleculeGPTDataset
82
+ from .tag_dataset import TAGDataset
79
83
 
80
84
  from .dbp15k import DBP15K
81
85
  from .aminer import AMiner
@@ -93,6 +97,9 @@ from .amazon_book import AmazonBook
93
97
  from .hm import HM
94
98
  from .ose_gvcs import OSE_GVCS
95
99
  from .rcdd import RCDD
100
+ from .opf import OPFDataset
101
+
102
+ from .cornell import CornellTemporalHyperGraphDataset
96
103
 
97
104
  from .fake import FakeDataset, FakeHeteroDataset
98
105
  from .sbm_dataset import StochasticBlockModelDataset
@@ -185,6 +192,10 @@ homo_datasets = [
185
192
  'MyketDataset',
186
193
  'BrcaTcga',
187
194
  'NeuroGraphDataset',
195
+ 'WebQSPDataset',
196
+ 'GitMolDataset',
197
+ 'MoleculeGPTDataset',
198
+ 'TAGDataset',
188
199
  ]
189
200
 
190
201
  hetero_datasets = [
@@ -204,6 +215,10 @@ hetero_datasets = [
204
215
  'HM',
205
216
  'OSE_GVCS',
206
217
  'RCDD',
218
+ 'OPFDataset',
219
+ ]
220
+ hyper_datasets = [
221
+ 'CornellTemporalHyperGraphDataset',
207
222
  ]
208
223
  synthetic_datasets = [
209
224
  'FakeDataset',
@@ -218,4 +233,4 @@ synthetic_datasets = [
218
233
  'BAShapes',
219
234
  ]
220
235
 
221
- __all__ = homo_datasets + hetero_datasets + synthetic_datasets
236
+ __all__ = homo_datasets + hetero_datasets + hyper_datasets + synthetic_datasets
@@ -19,17 +19,15 @@ class Actor(InMemoryDataset):
19
19
  actor's Wikipedia.
20
20
 
21
21
  Args:
22
- root (str): Root directory where the dataset should be saved.
23
- transform (callable, optional): A function/transform that takes in an
22
+ root: Root directory where the dataset should be saved.
23
+ transform: A function/transform that takes in an
24
24
  :obj:`torch_geometric.data.Data` object and returns a transformed
25
25
  version. The data object will be transformed before every access.
26
- (default: :obj:`None`)
27
- pre_transform (callable, optional): A function/transform that takes in
28
- an :obj:`torch_geometric.data.Data` object and returns a
29
- transformed version. The data object will be transformed before
30
- being saved to disk. (default: :obj:`None`)
31
- force_reload (bool, optional): Whether to re-process the dataset.
32
- (default: :obj:`False`)
26
+ pre_transform: A function/transform that takes in an
27
+ :class:`torch_geometric.data.Data` object and returns a transformed
28
+ version. The data object will be transformed before being saved to
29
+ disk.
30
+ force_reload: Whether to re-process the dataset.
33
31
 
34
32
  **STATS:**
35
33
 
@@ -76,7 +74,7 @@ class Actor(InMemoryDataset):
76
74
  download_url(f'{self.url}/splits/{f}', self.raw_dir)
77
75
 
78
76
  def process(self) -> None:
79
- with open(self.raw_paths[0], 'r') as f:
77
+ with open(self.raw_paths[0]) as f:
80
78
  node_data = [x.split('\t') for x in f.read().split('\n')[1:-1]]
81
79
 
82
80
  rows, cols = [], []
@@ -93,7 +91,7 @@ class Actor(InMemoryDataset):
93
91
  for n_id, _, label in node_data:
94
92
  y[int(n_id)] = int(label)
95
93
 
96
- with open(self.raw_paths[1], 'r') as f:
94
+ with open(self.raw_paths[1]) as f:
97
95
  edge_data = f.read().split('\n')[1:-1]
98
96
  edge_indices = [[int(v) for v in r.split('\t')] for r in edge_data]
99
97
  edge_index = torch.tensor(edge_indices).t().contiguous()
@@ -2,14 +2,13 @@ import json
2
2
  import os
3
3
  from typing import Callable, List, Optional
4
4
 
5
- import torch
6
-
7
5
  from torch_geometric.data import (
8
6
  Data,
9
7
  InMemoryDataset,
10
8
  download_url,
11
9
  extract_zip,
12
10
  )
11
+ from torch_geometric.io import fs
13
12
 
14
13
 
15
14
  class AirfRANS(InMemoryDataset):
@@ -47,26 +46,24 @@ class AirfRANS(InMemoryDataset):
47
46
  :obj:`torch_geometric.transforms.RadiusGraph` transform.
48
47
 
49
48
  Args:
50
- root (str): Root directory where the dataset should be saved.
51
- task (str): The task to study (:obj:`"full"`, :obj:`"scarce"`,
49
+ root: Root directory where the dataset should be saved.
50
+ task: The task to study (:obj:`"full"`, :obj:`"scarce"`,
52
51
  :obj:`"reynolds"`, :obj:`"aoa"`) that defines the utilized training
53
52
  and test splits.
54
- train (bool, optional): If :obj:`True`, loads the training dataset,
55
- otherwise the test dataset. (default: :obj:`True`)
56
- transform (callable, optional): A function/transform that takes in an
57
- :obj:`torch_geometric.data.Data` object and returns a transformed
53
+ train: If :obj:`True`, loads the training dataset, otherwise the test
54
+ dataset.
55
+ transform: A function/transform that takes in an
56
+ :class:`torch_geometric.data.Data` object and returns a transformed
58
57
  version. The data object will be transformed before every access.
59
- (default: :obj:`None`)
60
- pre_transform (callable, optional): A function/transform that takes in
61
- an :obj:`torch_geometric.data.Data` object and returns a
58
+ pre_transform: A function/transform that takes in an
59
+ :class:`torch_geometric.data.Data` object and returns a
62
60
  transformed version. The data object will be transformed before
63
- being saved to disk. (default: :obj:`None`)
64
- pre_filter (callable, optional): A function that takes in an
61
+ being saved to disk.
62
+ pre_filter: A function that takes in an
65
63
  :obj:`torch_geometric.data.Data` object and returns a boolean
66
64
  value, indicating whether the data object should be included in the
67
- final dataset. (default: :obj:`None`)
68
- force_reload (bool, optional): Whether to re-process the dataset.
69
- (default: :obj:`False`)
65
+ final dataset.
66
+ force_reload: Whether to re-process the dataset.
70
67
 
71
68
  **STATS:**
72
69
 
@@ -123,13 +120,13 @@ class AirfRANS(InMemoryDataset):
123
120
  os.unlink(path)
124
121
 
125
122
  def process(self) -> None:
126
- with open(self.raw_paths[1], 'r') as f:
123
+ with open(self.raw_paths[1]) as f:
127
124
  manifest = json.load(f)
128
125
  total = manifest['full_train'] + manifest['full_test']
129
126
  partial = set(manifest[f'{self.task}_{self.split}'])
130
127
 
131
128
  data_list = []
132
- raw_data = torch.load(self.raw_paths[0])
129
+ raw_data = fs.torch_load(self.raw_paths[0])
133
130
  for k, s in enumerate(total):
134
131
  if s in partial:
135
132
  data = Data(**raw_data[k])
@@ -14,22 +14,20 @@ class Airports(InMemoryDataset):
14
14
  and labels correspond to activity levels.
15
15
  Features are given by one-hot encoded node identifiers, as described in the
16
16
  `"GraLSP: Graph Neural Networks with Local Structural Patterns"
17
- ` <https://arxiv.org/abs/1911.07675>`_ paper.
17
+ <https://arxiv.org/abs/1911.07675>`_ paper.
18
18
 
19
19
  Args:
20
- root (str): Root directory where the dataset should be saved.
21
- name (str): The name of the dataset (:obj:`"USA"`, :obj:`"Brazil"`,
20
+ root: Root directory where the dataset should be saved.
21
+ name: The name of the dataset (:obj:`"USA"`, :obj:`"Brazil"`,
22
22
  :obj:`"Europe"`).
23
- transform (callable, optional): A function/transform that takes in an
24
- :obj:`torch_geometric.data.Data` object and returns a transformed
23
+ transform: A function/transform that takes in an
24
+ :class:`torch_geometric.data.Data` object and returns a transformed
25
25
  version. The data object will be transformed before every access.
26
- (default: :obj:`None`)
27
26
  pre_transform (callable, optional): A function/transform that takes in
28
- an :obj:`torch_geometric.data.Data` object and returns a
27
+ :class:`torch_geometric.data.Data` object and returns a
29
28
  transformed version. The data object will be transformed before
30
- being saved to disk. (default: :obj:`None`)
31
- force_reload (bool, optional): Whether to re-process the dataset.
32
- (default: :obj:`False`)
29
+ being saved to disk.
30
+ force_reload: Whether to re-process the dataset.
33
31
  """
34
32
  edge_url = ('https://github.com/leoribeiro/struc2vec/'
35
33
  'raw/master/graph/{}-airports.edgelist')
@@ -75,7 +73,7 @@ class Airports(InMemoryDataset):
75
73
 
76
74
  def process(self) -> None:
77
75
  index_map, ys = {}, []
78
- with open(self.raw_paths[1], 'r') as f:
76
+ with open(self.raw_paths[1]) as f:
79
77
  rows = f.read().split('\n')[1:-1]
80
78
  for i, row in enumerate(rows):
81
79
  idx, label = row.split()
@@ -85,7 +83,7 @@ class Airports(InMemoryDataset):
85
83
  x = torch.eye(y.size(0))
86
84
 
87
85
  edge_indices = []
88
- with open(self.raw_paths[0], 'r') as f:
86
+ with open(self.raw_paths[0]) as f:
89
87
  rows = f.read().split('\n')[:-1]
90
88
  for row in rows:
91
89
  src, dst = row.split()
@@ -15,19 +15,16 @@ class Amazon(InMemoryDataset):
15
15
  map goods to their respective product category.
16
16
 
17
17
  Args:
18
- root (str): Root directory where the dataset should be saved.
19
- name (str): The name of the dataset (:obj:`"Computers"`,
20
- :obj:`"Photo"`).
21
- transform (callable, optional): A function/transform that takes in an
22
- :obj:`torch_geometric.data.Data` object and returns a transformed
18
+ root: Root directory where the dataset should be saved.
19
+ name: The name of the dataset (:obj:`"Computers"`, :obj:`"Photo"`).
20
+ transform: A function/transform that takes in a
21
+ :class:`torch_geometric.data.Data` object and returns a transformed
23
22
  version. The data object will be transformed before every access.
24
- (default: :obj:`None`)
25
- pre_transform (callable, optional): A function/transform that takes in
26
- an :obj:`torch_geometric.data.Data` object and returns a
23
+ pre_transform: A function/transform that takes in an
24
+ :class:`torch_geometric.data.Data` object and returns a
27
25
  transformed version. The data object will be transformed before
28
- being saved to disk. (default: :obj:`None`)
29
- force_reload (bool, optional): Whether to re-process the dataset.
30
- (default: :obj:`False`)
26
+ being saved to disk.
27
+ force_reload: Whether to re-process the dataset.
31
28
 
32
29
  **STATS:**
33
30
 
@@ -14,17 +14,16 @@ class AmazonBook(InMemoryDataset):
14
14
  No labels or features are provided.
15
15
 
16
16
  Args:
17
- root (str): Root directory where the dataset should be saved.
18
- transform (callable, optional): A function/transform that takes in an
19
- :obj:`torch_geometric.data.HeteroData` object and returns a
17
+ root: Root directory where the dataset should be saved.
18
+ transform: A function/transform that takes in an
19
+ :class:`torch_geometric.data.HeteroData` object and returns a
20
20
  transformed version. The data object will be transformed before
21
- every access. (default: :obj:`None`)
22
- pre_transform (callable, optional): A function/transform that takes in
23
- an :obj:`torch_geometric.data.HeteroData` object and returns a
21
+ every access.
22
+ pre_transform: A function/transform that takes in an
23
+ :class:`torch_geometric.data.HeteroData` object and returns a
24
24
  transformed version. The data object will be transformed before
25
- being saved to disk. (default: :obj:`None`)
26
- force_reload (bool, optional): Whether to re-process the dataset.
27
- (default: :obj:`False`)
25
+ being saved to disk.
26
+ force_reload: Whether to re-process the dataset.
28
27
  """
29
28
  url = ('https://raw.githubusercontent.com/gusye1234/LightGCN-PyTorch/'
30
29
  'master/data/amazon-book')
@@ -67,7 +66,7 @@ class AmazonBook(InMemoryDataset):
67
66
  attr_names = ['edge_index', 'edge_label_index']
68
67
  for path, attr_name in zip(self.raw_paths[2:], attr_names):
69
68
  rows, cols = [], []
70
- with open(path, 'r') as f:
69
+ with open(path) as f:
71
70
  lines = f.readlines()
72
71
  for line in lines:
73
72
  indices = line.strip().split(' ')
@@ -3,7 +3,6 @@ import os.path as osp
3
3
  from typing import Callable, List, Optional
4
4
 
5
5
  import numpy as np
6
- import scipy.sparse as sp
7
6
  import torch
8
7
 
9
8
  from torch_geometric.data import Data, InMemoryDataset, download_google_url
@@ -15,17 +14,15 @@ class AmazonProducts(InMemoryDataset):
15
14
  containing products and its categories.
16
15
 
17
16
  Args:
18
- root (str): Root directory where the dataset should be saved.
19
- transform (callable, optional): A function/transform that takes in an
20
- :obj:`torch_geometric.data.Data` object and returns a transformed
17
+ root: Root directory where the dataset should be saved.
18
+ transform: A function/transform that takes in an
19
+ :class:`torch_geometric.data.Data` object and returns a transformed
21
20
  version. The data object will be transformed before every access.
22
- (default: :obj:`None`)
23
- pre_transform (callable, optional): A function/transform that takes in
24
- an :obj:`torch_geometric.data.Data` object and returns a
21
+ pre_transform: A function/transform that takes in a
22
+ :class:`torch_geometric.data.Data` object and returns a
25
23
  transformed version. The data object will be transformed before
26
- being saved to disk. (default: :obj:`None`)
27
- force_reload (bool, optional): Whether to re-process the dataset.
28
- (default: :obj:`False`)
24
+ being saved to disk.
25
+ force_reload: Whether to re-process the dataset.
29
26
 
30
27
  **STATS:**
31
28
 
@@ -73,6 +70,8 @@ class AmazonProducts(InMemoryDataset):
73
70
  download_google_url(self.role_id, self.raw_dir, 'role.json')
74
71
 
75
72
  def process(self) -> None:
73
+ import scipy.sparse as sp
74
+
76
75
  f = np.load(osp.join(self.raw_dir, 'adj_full.npz'))
77
76
  adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape'])
78
77
  adj = adj.tocoo()
@@ -24,17 +24,16 @@ class AMiner(InMemoryDataset):
24
24
  truth labels for a subset of nodes.
25
25
 
26
26
  Args:
27
- root (str): Root directory where the dataset should be saved.
28
- transform (callable, optional): A function/transform that takes in an
29
- :obj:`torch_geometric.data.HeteroData` object and returns a
27
+ root: Root directory where the dataset should be saved.
28
+ transform: A function/transform that takes in a
29
+ :class:`torch_geometric.data.HeteroData` object and returns a
30
30
  transformed version. The data object will be transformed before
31
- every access. (default: :obj:`None`)
32
- pre_transform (callable, optional): A function/transform that takes in
33
- an :obj:`torch_geometric.data.HeteroData` object and returns a
31
+ every access.
32
+ pre_transform: A function/transform that takes in a
33
+ :class:`torch_geometric.data.HeteroData` object and returns a
34
34
  transformed version. The data object will be transformed before
35
- being saved to disk. (default: :obj:`None`)
36
- force_reload (bool, optional): Whether to re-process the dataset.
37
- (default: :obj:`False`)
35
+ being saved to disk.
36
+ force_reload: Whether to re-process the dataset.
38
37
  """
39
38
 
40
39
  url = 'https://www.dropbox.com/s/1bnz8r7mofx0osf/net_aminer.zip?dl=1'
@@ -30,25 +30,22 @@ class AQSOL(InMemoryDataset):
30
30
  the :class:`~torch_geometric.datasets.ZINC` dataset.
31
31
 
32
32
  Args:
33
- root (str): Root directory where the dataset should be saved.
34
- split (str, optional): If :obj:`"train"`, loads the training dataset.
33
+ root: Root directory where the dataset should be saved.
34
+ split: If :obj:`"train"`, loads the training dataset.
35
35
  If :obj:`"val"`, loads the validation dataset.
36
36
  If :obj:`"test"`, loads the test dataset.
37
- (default: :obj:`"train"`)
38
- transform (callable, optional): A function/transform that takes in an
39
- :obj:`torch_geometric.data.Data` object and returns a transformed
37
+ transform: A function/transform that takes in a
38
+ :class:`torch_geometric.data.Data` object and returns a transformed
40
39
  version. The data object will be transformed before every access.
41
- (default: :obj:`None`)
42
- pre_transform (callable, optional): A function/transform that takes in
43
- an :obj:`torch_geometric.data.Data` object and returns a
40
+ pre_transform: A function/transform that takes in a
41
+ :class:`torch_geometric.data.Data` object and returns a
44
42
  transformed version. The data object will be transformed before
45
- being saved to disk. (default: :obj:`None`)
43
+ being saved to disk.
46
44
  pre_filter (callable, optional): A function that takes in an
47
- :obj:`torch_geometric.data.Data` object and returns a boolean
45
+ :class:`torch_geometric.data.Data` object and returns a boolean
48
46
  value, indicating whether the data object should be included in
49
- the final dataset. (default: :obj:`None`)
50
- force_reload (bool, optional): Whether to re-process the dataset.
51
- (default: :obj:`False`)
47
+ the final dataset.
48
+ force_reload: Whether to re-process the dataset.
52
49
 
53
50
  **STATS:**
54
51
 
@@ -2,7 +2,6 @@ import os
2
2
  import os.path as osp
3
3
  from typing import Callable, List, Optional
4
4
 
5
- import scipy.sparse as sp
6
5
  import torch
7
6
 
8
7
  from torch_geometric.data import (
@@ -20,21 +19,19 @@ class AttributedGraphDataset(InMemoryDataset):
20
19
  <https://arxiv.org/abs/2009.00826>`_ paper.
21
20
 
22
21
  Args:
23
- root (str): Root directory where the dataset should be saved.
24
- name (str): The name of the dataset (:obj:`"Wiki"`, :obj:`"Cora"`
22
+ root: Root directory where the dataset should be saved.
23
+ name: The name of the dataset (:obj:`"Wiki"`, :obj:`"Cora"`,
25
24
  :obj:`"CiteSeer"`, :obj:`"PubMed"`, :obj:`"BlogCatalog"`,
26
25
  :obj:`"PPI"`, :obj:`"Flickr"`, :obj:`"Facebook"`, :obj:`"Twitter"`,
27
26
  :obj:`"TWeibo"`, :obj:`"MAG"`).
28
- transform (callable, optional): A function/transform that takes in an
29
- :obj:`torch_geometric.data.Data` object and returns a transformed
27
+ transform: A function/transform that takes in a
28
+ :class:`torch_geometric.data.Data` object and returns a transformed
30
29
  version. The data object will be transformed before every access.
31
- (default: :obj:`None`)
32
- pre_transform (callable, optional): A function/transform that takes in
33
- an :obj:`torch_geometric.data.Data` object and returns a
30
+ pre_transform: A function/transform that takes in a
31
+ :class:`torch_geometric.data.Data` object and returns a
34
32
  transformed version. The data object will be transformed before
35
- being saved to disk. (default: :obj:`None`)
36
- force_reload (bool, optional): Whether to re-process the dataset.
37
- (default: :obj:`False`)
33
+ being saved to disk.
34
+ force_reload: Whether to re-process the dataset.
38
35
 
39
36
  **STATS:**
40
37
 
@@ -156,6 +153,7 @@ class AttributedGraphDataset(InMemoryDataset):
156
153
 
157
154
  def process(self) -> None:
158
155
  import pandas as pd
156
+ import scipy.sparse as sp
159
157
 
160
158
  x = sp.load_npz(self.raw_paths[0]).tocsr()
161
159
  if x.shape[-1] > 10000 or self.name == 'mag':
@@ -172,7 +170,7 @@ class AttributedGraphDataset(InMemoryDataset):
172
170
  engine='python')
173
171
  edge_index = torch.from_numpy(df.values).t().contiguous()
174
172
 
175
- with open(self.raw_paths[2], 'r') as f:
173
+ with open(self.raw_paths[2]) as f:
176
174
  rows = f.read().split('\n')[:-1]
177
175
  ys = [[int(y) - 1 for y in row.split()[1:]] for row in rows]
178
176
  multilabel = max([len(y) for y in ys]) > 1
@@ -25,21 +25,19 @@ class BAMultiShapesDataset(InMemoryDataset):
25
25
  This dataset is pre-computed from the official implementation.
26
26
 
27
27
  Args:
28
- root (str): Root directory where the dataset should be saved.
29
- transform (callable, optional): A function/transform that takes in an
30
- :obj:`torch_geometric.data.Data` object and returns a transformed
28
+ root: Root directory where the dataset should be saved.
29
+ transform: A function/transform that takes in a
30
+ :class:`torch_geometric.data.Data` object and returns a transformed
31
31
  version. The data object will be transformed before every access.
32
- (default: :obj:`None`)
33
- pre_transform (callable, optional): A function/transform that takes in
34
- an :obj:`torch_geometric.data.Data` object and returns a
32
+ pre_transform: A function/transform that takes in a
33
+ :class:`torch_geometric.data.Data` object and returns a
35
34
  transformed version. The data object will be transformed before
36
- being saved to disk. (default: :obj:`None`)
37
- pre_filter (callable, optional): A function that takes in an
38
- :obj:`torch_geometric.data.Data` object and returns a boolean
35
+ being saved to disk.
36
+ pre_filter: A function that takes in a
37
+ :class:`torch_geometric.data.Data` object and returns a boolean
39
38
  value, indicating whether the data object should be included in the
40
- final dataset. (default: :obj:`None`)
41
- force_reload (bool, optional): Whether to re-process the dataset.
42
- (default: :obj:`False`)
39
+ final dataset.
40
+ force_reload: Whether to re-process the dataset.
43
41
 
44
42
  **STATS:**
45
43