pyg-nightly 2.6.0.dev20240318__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +13 -7
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +317 -65
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +3 -5
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +329 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +56 -22
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
@@ -13,7 +13,7 @@ from torch_geometric.typing import InputEdges, InputNodes, OptTensor
|
|
13
13
|
try:
|
14
14
|
from pytorch_lightning import LightningDataModule as PLLightningDataModule
|
15
15
|
no_pytorch_lightning = False
|
16
|
-
except
|
16
|
+
except ImportError:
|
17
17
|
PLLightningDataModule = object # type: ignore
|
18
18
|
no_pytorch_lightning = True
|
19
19
|
|
@@ -221,7 +221,7 @@ class LightningDataset(LightningDataModule):
|
|
221
221
|
speed.html>`__ are supported in order to correctly share data across
|
222
222
|
all devices/processes:
|
223
223
|
|
224
|
-
.. code-block::
|
224
|
+
.. code-block:: python
|
225
225
|
|
226
226
|
import pytorch_lightning as pl
|
227
227
|
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
|
@@ -332,7 +332,7 @@ class LightningNodeData(LightningData):
|
|
332
332
|
speed.html>`__ are supported in order to correctly share data across
|
333
333
|
all devices/processes:
|
334
334
|
|
335
|
-
.. code-block::
|
335
|
+
.. code-block:: python
|
336
336
|
|
337
337
|
import pytorch_lightning as pl
|
338
338
|
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
|
@@ -525,7 +525,7 @@ class LightningLinkData(LightningData):
|
|
525
525
|
speed.html>`__ are supported in order to correctly share data across
|
526
526
|
all devices/processes:
|
527
527
|
|
528
|
-
.. code-block::
|
528
|
+
.. code-block:: python
|
529
529
|
|
530
530
|
import pytorch_lightning as pl
|
531
531
|
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
|
torch_geometric/data/separate.py
CHANGED
@@ -3,7 +3,7 @@ from typing import Any, Type, TypeVar
|
|
3
3
|
|
4
4
|
from torch import Tensor
|
5
5
|
|
6
|
-
from torch_geometric import EdgeIndex
|
6
|
+
from torch_geometric import EdgeIndex, Index
|
7
7
|
from torch_geometric.data.data import BaseData
|
8
8
|
from torch_geometric.data.storage import BaseStorage
|
9
9
|
from torch_geometric.typing import SparseTensor, TensorFrame
|
@@ -76,6 +76,11 @@ def _separate(
|
|
76
76
|
value = narrow(values, cat_dim or 0, start, end - start)
|
77
77
|
value = value.squeeze(0) if cat_dim is None else value
|
78
78
|
|
79
|
+
if isinstance(values, Index) and values._cat_metadata is not None:
|
80
|
+
# Reconstruct original `Index` metadata:
|
81
|
+
value._dim_size = values._cat_metadata.dim_size[idx]
|
82
|
+
value._is_sorted = values._cat_metadata.is_sorted[idx]
|
83
|
+
|
79
84
|
if isinstance(values, EdgeIndex) and values._cat_metadata is not None:
|
80
85
|
# Reconstruct original `EdgeIndex` metadata:
|
81
86
|
value._sparse_size = values._cat_metadata.sparse_size[idx]
|
torch_geometric/data/storage.py
CHANGED
@@ -370,18 +370,20 @@ class BaseStorage(MutableMapping):
|
|
370
370
|
self,
|
371
371
|
start_time: Union[float, int],
|
372
372
|
end_time: Union[float, int],
|
373
|
+
attr: str = 'time',
|
373
374
|
) -> Self:
|
374
|
-
if
|
375
|
-
|
375
|
+
if attr in self:
|
376
|
+
time = self[attr]
|
377
|
+
mask = (time >= start_time) & (time <= end_time)
|
376
378
|
|
377
|
-
if self.is_node_attr(
|
379
|
+
if self.is_node_attr(attr):
|
378
380
|
keys = self.node_attrs()
|
379
|
-
elif self.is_edge_attr(
|
381
|
+
elif self.is_edge_attr(attr):
|
380
382
|
keys = self.edge_attrs()
|
381
383
|
|
382
384
|
self._select(keys, mask)
|
383
385
|
|
384
|
-
if self.is_node_attr(
|
386
|
+
if self.is_node_attr(attr) and 'num_nodes' in self:
|
385
387
|
self.num_nodes: Optional[int] = int(mask.sum())
|
386
388
|
|
387
389
|
return self
|
@@ -443,9 +445,9 @@ class NodeStorage(BaseStorage):
|
|
443
445
|
return self.edge_index.sparse_size(0)
|
444
446
|
if self.edge_index.sparse_size(1) is not None:
|
445
447
|
return self.edge_index.sparse_size(1)
|
446
|
-
if 'adj' in self and isinstance(self.adj, SparseTensor):
|
448
|
+
if 'adj' in self and isinstance(self.adj, (Tensor, SparseTensor)):
|
447
449
|
return self.adj.size(0)
|
448
|
-
if 'adj_t' in self and isinstance(self.adj_t, SparseTensor):
|
450
|
+
if 'adj_t' in self and isinstance(self.adj_t, (Tensor, SparseTensor)):
|
449
451
|
return self.adj_t.size(1)
|
450
452
|
warnings.warn(
|
451
453
|
f"Unable to accurately infer 'num_nodes' from the attribute set "
|
@@ -804,6 +806,10 @@ class GlobalStorage(NodeStorage, EdgeStorage):
|
|
804
806
|
return False
|
805
807
|
|
806
808
|
cat_dim = self._parent().__cat_dim__(key, value, self)
|
809
|
+
|
810
|
+
if not isinstance(cat_dim, int):
|
811
|
+
return False
|
812
|
+
|
807
813
|
num_nodes, num_edges = self.num_nodes, self.num_edges
|
808
814
|
|
809
815
|
if value.shape[cat_dim] != num_nodes:
|
@@ -850,6 +856,10 @@ class GlobalStorage(NodeStorage, EdgeStorage):
|
|
850
856
|
return False
|
851
857
|
|
852
858
|
cat_dim = self._parent().__cat_dim__(key, value, self)
|
859
|
+
|
860
|
+
if not isinstance(cat_dim, int):
|
861
|
+
return False
|
862
|
+
|
853
863
|
num_nodes, num_edges = self.num_nodes, self.num_edges
|
854
864
|
|
855
865
|
if value.shape[cat_dim] != num_edges:
|
torch_geometric/data/summary.py
CHANGED
@@ -117,7 +117,14 @@ class Summary:
|
|
117
117
|
num_edges_per_type=num_edges_per_type,
|
118
118
|
)
|
119
119
|
|
120
|
-
def
|
120
|
+
def format(self, fmt: str = "psql") -> str:
|
121
|
+
r"""Formats summary statistics of the dataset.
|
122
|
+
|
123
|
+
Args:
|
124
|
+
fmt (str, optional): Summary tables format. Available table formats
|
125
|
+
can be found `here <https://github.com/astanin/python-tabulate?
|
126
|
+
tab=readme-ov-file#table-format>`__. (default: :obj:`"psql"`)
|
127
|
+
"""
|
121
128
|
from tabulate import tabulate
|
122
129
|
|
123
130
|
body = f'{self.name} (#graphs={self.num_graphs}):\n'
|
@@ -127,7 +134,7 @@ class Summary:
|
|
127
134
|
for field in Stats.__dataclass_fields__:
|
128
135
|
row = [field] + [f'{getattr(s, field):.1f}' for s in stats]
|
129
136
|
content.append(row)
|
130
|
-
body += tabulate(content, headers='firstrow', tablefmt=
|
137
|
+
body += tabulate(content, headers='firstrow', tablefmt=fmt)
|
131
138
|
|
132
139
|
if self.num_nodes_per_type is not None:
|
133
140
|
content = [['']]
|
@@ -140,7 +147,7 @@ class Summary:
|
|
140
147
|
]
|
141
148
|
content.append(row)
|
142
149
|
body += "\nNumber of nodes per node type:\n"
|
143
|
-
body += tabulate(content, headers='firstrow', tablefmt=
|
150
|
+
body += tabulate(content, headers='firstrow', tablefmt=fmt)
|
144
151
|
|
145
152
|
if self.num_edges_per_type is not None:
|
146
153
|
content = [['']]
|
@@ -156,6 +163,9 @@ class Summary:
|
|
156
163
|
]
|
157
164
|
content.append(row)
|
158
165
|
body += "\nNumber of edges per edge type:\n"
|
159
|
-
body += tabulate(content, headers='firstrow', tablefmt=
|
166
|
+
body += tabulate(content, headers='firstrow', tablefmt=fmt)
|
160
167
|
|
161
168
|
return body
|
169
|
+
|
170
|
+
def __repr__(self) -> str:
|
171
|
+
return self.format()
|
torch_geometric/data/temporal.py
CHANGED
@@ -156,8 +156,7 @@ class TemporalData(BaseData):
|
|
156
156
|
return self.num_events
|
157
157
|
|
158
158
|
def __call__(self, *args: List[str]) -> Iterable:
|
159
|
-
|
160
|
-
yield key, value
|
159
|
+
yield from self._store.items(*args)
|
161
160
|
|
162
161
|
def __copy__(self):
|
163
162
|
out = self.__class__.__new__(self.__class__)
|
@@ -61,7 +61,6 @@ from .gemsec import GemsecDeezer
|
|
61
61
|
from .twitch import Twitch
|
62
62
|
from .airports import Airports
|
63
63
|
from .lrgb import LRGBDataset
|
64
|
-
from .neurograph import NeuroGraphDataset
|
65
64
|
from .malnet_tiny import MalNetTiny
|
66
65
|
from .omdb import OMDB
|
67
66
|
from .polblogs import PolBlogs
|
@@ -76,6 +75,11 @@ from .jodie import JODIEDataset
|
|
76
75
|
from .wikidata import Wikidata5M
|
77
76
|
from .myket import MyketDataset
|
78
77
|
from .brca_tgca import BrcaTcga
|
78
|
+
from .neurograph import NeuroGraphDataset
|
79
|
+
from .web_qsp_dataset import WebQSPDataset
|
80
|
+
from .git_mol_dataset import GitMolDataset
|
81
|
+
from .molecule_gpt_dataset import MoleculeGPTDataset
|
82
|
+
from .tag_dataset import TAGDataset
|
79
83
|
|
80
84
|
from .dbp15k import DBP15K
|
81
85
|
from .aminer import AMiner
|
@@ -93,6 +97,9 @@ from .amazon_book import AmazonBook
|
|
93
97
|
from .hm import HM
|
94
98
|
from .ose_gvcs import OSE_GVCS
|
95
99
|
from .rcdd import RCDD
|
100
|
+
from .opf import OPFDataset
|
101
|
+
|
102
|
+
from .cornell import CornellTemporalHyperGraphDataset
|
96
103
|
|
97
104
|
from .fake import FakeDataset, FakeHeteroDataset
|
98
105
|
from .sbm_dataset import StochasticBlockModelDataset
|
@@ -185,6 +192,10 @@ homo_datasets = [
|
|
185
192
|
'MyketDataset',
|
186
193
|
'BrcaTcga',
|
187
194
|
'NeuroGraphDataset',
|
195
|
+
'WebQSPDataset',
|
196
|
+
'GitMolDataset',
|
197
|
+
'MoleculeGPTDataset',
|
198
|
+
'TAGDataset',
|
188
199
|
]
|
189
200
|
|
190
201
|
hetero_datasets = [
|
@@ -204,6 +215,10 @@ hetero_datasets = [
|
|
204
215
|
'HM',
|
205
216
|
'OSE_GVCS',
|
206
217
|
'RCDD',
|
218
|
+
'OPFDataset',
|
219
|
+
]
|
220
|
+
hyper_datasets = [
|
221
|
+
'CornellTemporalHyperGraphDataset',
|
207
222
|
]
|
208
223
|
synthetic_datasets = [
|
209
224
|
'FakeDataset',
|
@@ -218,4 +233,4 @@ synthetic_datasets = [
|
|
218
233
|
'BAShapes',
|
219
234
|
]
|
220
235
|
|
221
|
-
__all__ = homo_datasets + hetero_datasets + synthetic_datasets
|
236
|
+
__all__ = homo_datasets + hetero_datasets + hyper_datasets + synthetic_datasets
|
@@ -19,17 +19,15 @@ class Actor(InMemoryDataset):
|
|
19
19
|
actor's Wikipedia.
|
20
20
|
|
21
21
|
Args:
|
22
|
-
root
|
23
|
-
transform
|
22
|
+
root: Root directory where the dataset should be saved.
|
23
|
+
transform: A function/transform that takes in an
|
24
24
|
:obj:`torch_geometric.data.Data` object and returns a transformed
|
25
25
|
version. The data object will be transformed before every access.
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
force_reload (bool, optional): Whether to re-process the dataset.
|
32
|
-
(default: :obj:`False`)
|
26
|
+
pre_transform: A function/transform that takes in an
|
27
|
+
:class:`torch_geometric.data.Data` object and returns a transformed
|
28
|
+
version. The data object will be transformed before being saved to
|
29
|
+
disk.
|
30
|
+
force_reload: Whether to re-process the dataset.
|
33
31
|
|
34
32
|
**STATS:**
|
35
33
|
|
@@ -76,7 +74,7 @@ class Actor(InMemoryDataset):
|
|
76
74
|
download_url(f'{self.url}/splits/{f}', self.raw_dir)
|
77
75
|
|
78
76
|
def process(self) -> None:
|
79
|
-
with open(self.raw_paths[0]
|
77
|
+
with open(self.raw_paths[0]) as f:
|
80
78
|
node_data = [x.split('\t') for x in f.read().split('\n')[1:-1]]
|
81
79
|
|
82
80
|
rows, cols = [], []
|
@@ -93,7 +91,7 @@ class Actor(InMemoryDataset):
|
|
93
91
|
for n_id, _, label in node_data:
|
94
92
|
y[int(n_id)] = int(label)
|
95
93
|
|
96
|
-
with open(self.raw_paths[1]
|
94
|
+
with open(self.raw_paths[1]) as f:
|
97
95
|
edge_data = f.read().split('\n')[1:-1]
|
98
96
|
edge_indices = [[int(v) for v in r.split('\t')] for r in edge_data]
|
99
97
|
edge_index = torch.tensor(edge_indices).t().contiguous()
|
@@ -2,14 +2,13 @@ import json
|
|
2
2
|
import os
|
3
3
|
from typing import Callable, List, Optional
|
4
4
|
|
5
|
-
import torch
|
6
|
-
|
7
5
|
from torch_geometric.data import (
|
8
6
|
Data,
|
9
7
|
InMemoryDataset,
|
10
8
|
download_url,
|
11
9
|
extract_zip,
|
12
10
|
)
|
11
|
+
from torch_geometric.io import fs
|
13
12
|
|
14
13
|
|
15
14
|
class AirfRANS(InMemoryDataset):
|
@@ -47,26 +46,24 @@ class AirfRANS(InMemoryDataset):
|
|
47
46
|
:obj:`torch_geometric.transforms.RadiusGraph` transform.
|
48
47
|
|
49
48
|
Args:
|
50
|
-
root
|
51
|
-
task
|
49
|
+
root: Root directory where the dataset should be saved.
|
50
|
+
task: The task to study (:obj:`"full"`, :obj:`"scarce"`,
|
52
51
|
:obj:`"reynolds"`, :obj:`"aoa"`) that defines the utilized training
|
53
52
|
and test splits.
|
54
|
-
train
|
55
|
-
|
56
|
-
transform
|
57
|
-
:
|
53
|
+
train: If :obj:`True`, loads the training dataset, otherwise the test
|
54
|
+
dataset.
|
55
|
+
transform: A function/transform that takes in an
|
56
|
+
:class:`torch_geometric.data.Data` object and returns a transformed
|
58
57
|
version. The data object will be transformed before every access.
|
59
|
-
|
60
|
-
|
61
|
-
an :obj:`torch_geometric.data.Data` object and returns a
|
58
|
+
pre_transform: A function/transform that takes in an
|
59
|
+
:class:`torch_geometric.data.Data` object and returns a
|
62
60
|
transformed version. The data object will be transformed before
|
63
|
-
being saved to disk.
|
64
|
-
pre_filter
|
61
|
+
being saved to disk.
|
62
|
+
pre_filter: A function that takes in an
|
65
63
|
:obj:`torch_geometric.data.Data` object and returns a boolean
|
66
64
|
value, indicating whether the data object should be included in the
|
67
|
-
final dataset.
|
68
|
-
force_reload
|
69
|
-
(default: :obj:`False`)
|
65
|
+
final dataset.
|
66
|
+
force_reload: Whether to re-process the dataset.
|
70
67
|
|
71
68
|
**STATS:**
|
72
69
|
|
@@ -123,13 +120,13 @@ class AirfRANS(InMemoryDataset):
|
|
123
120
|
os.unlink(path)
|
124
121
|
|
125
122
|
def process(self) -> None:
|
126
|
-
with open(self.raw_paths[1]
|
123
|
+
with open(self.raw_paths[1]) as f:
|
127
124
|
manifest = json.load(f)
|
128
125
|
total = manifest['full_train'] + manifest['full_test']
|
129
126
|
partial = set(manifest[f'{self.task}_{self.split}'])
|
130
127
|
|
131
128
|
data_list = []
|
132
|
-
raw_data =
|
129
|
+
raw_data = fs.torch_load(self.raw_paths[0])
|
133
130
|
for k, s in enumerate(total):
|
134
131
|
if s in partial:
|
135
132
|
data = Data(**raw_data[k])
|
@@ -14,22 +14,20 @@ class Airports(InMemoryDataset):
|
|
14
14
|
and labels correspond to activity levels.
|
15
15
|
Features are given by one-hot encoded node identifiers, as described in the
|
16
16
|
`"GraLSP: Graph Neural Networks with Local Structural Patterns"
|
17
|
-
|
17
|
+
<https://arxiv.org/abs/1911.07675>`_ paper.
|
18
18
|
|
19
19
|
Args:
|
20
|
-
root
|
21
|
-
name
|
20
|
+
root: Root directory where the dataset should be saved.
|
21
|
+
name: The name of the dataset (:obj:`"USA"`, :obj:`"Brazil"`,
|
22
22
|
:obj:`"Europe"`).
|
23
|
-
transform
|
24
|
-
:
|
23
|
+
transform: A function/transform that takes in an
|
24
|
+
:class:`torch_geometric.data.Data` object and returns a transformed
|
25
25
|
version. The data object will be transformed before every access.
|
26
|
-
(default: :obj:`None`)
|
27
26
|
pre_transform (callable, optional): A function/transform that takes in
|
28
|
-
|
27
|
+
:class:`torch_geometric.data.Data` object and returns a
|
29
28
|
transformed version. The data object will be transformed before
|
30
|
-
being saved to disk.
|
31
|
-
force_reload
|
32
|
-
(default: :obj:`False`)
|
29
|
+
being saved to disk.
|
30
|
+
force_reload: Whether to re-process the dataset.
|
33
31
|
"""
|
34
32
|
edge_url = ('https://github.com/leoribeiro/struc2vec/'
|
35
33
|
'raw/master/graph/{}-airports.edgelist')
|
@@ -75,7 +73,7 @@ class Airports(InMemoryDataset):
|
|
75
73
|
|
76
74
|
def process(self) -> None:
|
77
75
|
index_map, ys = {}, []
|
78
|
-
with open(self.raw_paths[1]
|
76
|
+
with open(self.raw_paths[1]) as f:
|
79
77
|
rows = f.read().split('\n')[1:-1]
|
80
78
|
for i, row in enumerate(rows):
|
81
79
|
idx, label = row.split()
|
@@ -85,7 +83,7 @@ class Airports(InMemoryDataset):
|
|
85
83
|
x = torch.eye(y.size(0))
|
86
84
|
|
87
85
|
edge_indices = []
|
88
|
-
with open(self.raw_paths[0]
|
86
|
+
with open(self.raw_paths[0]) as f:
|
89
87
|
rows = f.read().split('\n')[:-1]
|
90
88
|
for row in rows:
|
91
89
|
src, dst = row.split()
|
@@ -15,19 +15,16 @@ class Amazon(InMemoryDataset):
|
|
15
15
|
map goods to their respective product category.
|
16
16
|
|
17
17
|
Args:
|
18
|
-
root
|
19
|
-
name
|
20
|
-
|
21
|
-
|
22
|
-
:obj:`torch_geometric.data.Data` object and returns a transformed
|
18
|
+
root: Root directory where the dataset should be saved.
|
19
|
+
name: The name of the dataset (:obj:`"Computers"`, :obj:`"Photo"`).
|
20
|
+
transform: A function/transform that takes in a
|
21
|
+
:class:`torch_geometric.data.Data` object and returns a transformed
|
23
22
|
version. The data object will be transformed before every access.
|
24
|
-
|
25
|
-
|
26
|
-
an :obj:`torch_geometric.data.Data` object and returns a
|
23
|
+
pre_transform: A function/transform that takes in an
|
24
|
+
:class:`torch_geometric.data.Data` object and returns a
|
27
25
|
transformed version. The data object will be transformed before
|
28
|
-
being saved to disk.
|
29
|
-
force_reload
|
30
|
-
(default: :obj:`False`)
|
26
|
+
being saved to disk.
|
27
|
+
force_reload: Whether to re-process the dataset.
|
31
28
|
|
32
29
|
**STATS:**
|
33
30
|
|
@@ -14,17 +14,16 @@ class AmazonBook(InMemoryDataset):
|
|
14
14
|
No labels or features are provided.
|
15
15
|
|
16
16
|
Args:
|
17
|
-
root
|
18
|
-
transform
|
19
|
-
:
|
17
|
+
root: Root directory where the dataset should be saved.
|
18
|
+
transform: A function/transform that takes in an
|
19
|
+
:class:`torch_geometric.data.HeteroData` object and returns a
|
20
20
|
transformed version. The data object will be transformed before
|
21
|
-
every access.
|
22
|
-
pre_transform
|
23
|
-
|
21
|
+
every access.
|
22
|
+
pre_transform: A function/transform that takes in an
|
23
|
+
:class:`torch_geometric.data.HeteroData` object and returns a
|
24
24
|
transformed version. The data object will be transformed before
|
25
|
-
being saved to disk.
|
26
|
-
force_reload
|
27
|
-
(default: :obj:`False`)
|
25
|
+
being saved to disk.
|
26
|
+
force_reload: Whether to re-process the dataset.
|
28
27
|
"""
|
29
28
|
url = ('https://raw.githubusercontent.com/gusye1234/LightGCN-PyTorch/'
|
30
29
|
'master/data/amazon-book')
|
@@ -67,7 +66,7 @@ class AmazonBook(InMemoryDataset):
|
|
67
66
|
attr_names = ['edge_index', 'edge_label_index']
|
68
67
|
for path, attr_name in zip(self.raw_paths[2:], attr_names):
|
69
68
|
rows, cols = [], []
|
70
|
-
with open(path
|
69
|
+
with open(path) as f:
|
71
70
|
lines = f.readlines()
|
72
71
|
for line in lines:
|
73
72
|
indices = line.strip().split(' ')
|
@@ -3,7 +3,6 @@ import os.path as osp
|
|
3
3
|
from typing import Callable, List, Optional
|
4
4
|
|
5
5
|
import numpy as np
|
6
|
-
import scipy.sparse as sp
|
7
6
|
import torch
|
8
7
|
|
9
8
|
from torch_geometric.data import Data, InMemoryDataset, download_google_url
|
@@ -15,17 +14,15 @@ class AmazonProducts(InMemoryDataset):
|
|
15
14
|
containing products and its categories.
|
16
15
|
|
17
16
|
Args:
|
18
|
-
root
|
19
|
-
transform
|
20
|
-
:
|
17
|
+
root: Root directory where the dataset should be saved.
|
18
|
+
transform: A function/transform that takes in an
|
19
|
+
:class:`torch_geometric.data.Data` object and returns a transformed
|
21
20
|
version. The data object will be transformed before every access.
|
22
|
-
|
23
|
-
|
24
|
-
an :obj:`torch_geometric.data.Data` object and returns a
|
21
|
+
pre_transform: A function/transform that takes in a
|
22
|
+
:class:`torch_geometric.data.Data` object and returns a
|
25
23
|
transformed version. The data object will be transformed before
|
26
|
-
being saved to disk.
|
27
|
-
force_reload
|
28
|
-
(default: :obj:`False`)
|
24
|
+
being saved to disk.
|
25
|
+
force_reload: Whether to re-process the dataset.
|
29
26
|
|
30
27
|
**STATS:**
|
31
28
|
|
@@ -73,6 +70,8 @@ class AmazonProducts(InMemoryDataset):
|
|
73
70
|
download_google_url(self.role_id, self.raw_dir, 'role.json')
|
74
71
|
|
75
72
|
def process(self) -> None:
|
73
|
+
import scipy.sparse as sp
|
74
|
+
|
76
75
|
f = np.load(osp.join(self.raw_dir, 'adj_full.npz'))
|
77
76
|
adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape'])
|
78
77
|
adj = adj.tocoo()
|
@@ -24,17 +24,16 @@ class AMiner(InMemoryDataset):
|
|
24
24
|
truth labels for a subset of nodes.
|
25
25
|
|
26
26
|
Args:
|
27
|
-
root
|
28
|
-
transform
|
29
|
-
:
|
27
|
+
root: Root directory where the dataset should be saved.
|
28
|
+
transform: A function/transform that takes in a
|
29
|
+
:class:`torch_geometric.data.HeteroData` object and returns a
|
30
30
|
transformed version. The data object will be transformed before
|
31
|
-
every access.
|
32
|
-
pre_transform
|
33
|
-
|
31
|
+
every access.
|
32
|
+
pre_transform: A function/transform that takes in a
|
33
|
+
:class:`torch_geometric.data.HeteroData` object and returns a
|
34
34
|
transformed version. The data object will be transformed before
|
35
|
-
being saved to disk.
|
36
|
-
force_reload
|
37
|
-
(default: :obj:`False`)
|
35
|
+
being saved to disk.
|
36
|
+
force_reload: Whether to re-process the dataset.
|
38
37
|
"""
|
39
38
|
|
40
39
|
url = 'https://www.dropbox.com/s/1bnz8r7mofx0osf/net_aminer.zip?dl=1'
|
@@ -30,25 +30,22 @@ class AQSOL(InMemoryDataset):
|
|
30
30
|
the :class:`~torch_geometric.datasets.ZINC` dataset.
|
31
31
|
|
32
32
|
Args:
|
33
|
-
root
|
34
|
-
split
|
33
|
+
root: Root directory where the dataset should be saved.
|
34
|
+
split: If :obj:`"train"`, loads the training dataset.
|
35
35
|
If :obj:`"val"`, loads the validation dataset.
|
36
36
|
If :obj:`"test"`, loads the test dataset.
|
37
|
-
|
38
|
-
|
39
|
-
:obj:`torch_geometric.data.Data` object and returns a transformed
|
37
|
+
transform: A function/transform that takes in a
|
38
|
+
:class:`torch_geometric.data.Data` object and returns a transformed
|
40
39
|
version. The data object will be transformed before every access.
|
41
|
-
|
42
|
-
|
43
|
-
an :obj:`torch_geometric.data.Data` object and returns a
|
40
|
+
pre_transform: A function/transform that takes in a
|
41
|
+
:class:`torch_geometric.data.Data` object and returns a
|
44
42
|
transformed version. The data object will be transformed before
|
45
|
-
being saved to disk.
|
43
|
+
being saved to disk.
|
46
44
|
pre_filter (callable, optional): A function that takes in an
|
47
|
-
:
|
45
|
+
:class:`torch_geometric.data.Data` object and returns a boolean
|
48
46
|
value, indicating whether the data object should be included in
|
49
|
-
the final dataset.
|
50
|
-
force_reload
|
51
|
-
(default: :obj:`False`)
|
47
|
+
the final dataset.
|
48
|
+
force_reload: Whether to re-process the dataset.
|
52
49
|
|
53
50
|
**STATS:**
|
54
51
|
|
@@ -2,7 +2,6 @@ import os
|
|
2
2
|
import os.path as osp
|
3
3
|
from typing import Callable, List, Optional
|
4
4
|
|
5
|
-
import scipy.sparse as sp
|
6
5
|
import torch
|
7
6
|
|
8
7
|
from torch_geometric.data import (
|
@@ -20,21 +19,19 @@ class AttributedGraphDataset(InMemoryDataset):
|
|
20
19
|
<https://arxiv.org/abs/2009.00826>`_ paper.
|
21
20
|
|
22
21
|
Args:
|
23
|
-
root
|
24
|
-
name
|
22
|
+
root: Root directory where the dataset should be saved.
|
23
|
+
name: The name of the dataset (:obj:`"Wiki"`, :obj:`"Cora"`,
|
25
24
|
:obj:`"CiteSeer"`, :obj:`"PubMed"`, :obj:`"BlogCatalog"`,
|
26
25
|
:obj:`"PPI"`, :obj:`"Flickr"`, :obj:`"Facebook"`, :obj:`"Twitter"`,
|
27
26
|
:obj:`"TWeibo"`, :obj:`"MAG"`).
|
28
|
-
transform
|
29
|
-
:
|
27
|
+
transform: A function/transform that takes in a
|
28
|
+
:class:`torch_geometric.data.Data` object and returns a transformed
|
30
29
|
version. The data object will be transformed before every access.
|
31
|
-
|
32
|
-
|
33
|
-
an :obj:`torch_geometric.data.Data` object and returns a
|
30
|
+
pre_transform: A function/transform that takes in a
|
31
|
+
:class:`torch_geometric.data.Data` object and returns a
|
34
32
|
transformed version. The data object will be transformed before
|
35
|
-
being saved to disk.
|
36
|
-
force_reload
|
37
|
-
(default: :obj:`False`)
|
33
|
+
being saved to disk.
|
34
|
+
force_reload: Whether to re-process the dataset.
|
38
35
|
|
39
36
|
**STATS:**
|
40
37
|
|
@@ -156,6 +153,7 @@ class AttributedGraphDataset(InMemoryDataset):
|
|
156
153
|
|
157
154
|
def process(self) -> None:
|
158
155
|
import pandas as pd
|
156
|
+
import scipy.sparse as sp
|
159
157
|
|
160
158
|
x = sp.load_npz(self.raw_paths[0]).tocsr()
|
161
159
|
if x.shape[-1] > 10000 or self.name == 'mag':
|
@@ -172,7 +170,7 @@ class AttributedGraphDataset(InMemoryDataset):
|
|
172
170
|
engine='python')
|
173
171
|
edge_index = torch.from_numpy(df.values).t().contiguous()
|
174
172
|
|
175
|
-
with open(self.raw_paths[2]
|
173
|
+
with open(self.raw_paths[2]) as f:
|
176
174
|
rows = f.read().split('\n')[:-1]
|
177
175
|
ys = [[int(y) - 1 for y in row.split()[1:]] for row in rows]
|
178
176
|
multilabel = max([len(y) for y in ys]) > 1
|
@@ -25,21 +25,19 @@ class BAMultiShapesDataset(InMemoryDataset):
|
|
25
25
|
This dataset is pre-computed from the official implementation.
|
26
26
|
|
27
27
|
Args:
|
28
|
-
root
|
29
|
-
transform
|
30
|
-
:
|
28
|
+
root: Root directory where the dataset should be saved.
|
29
|
+
transform: A function/transform that takes in a
|
30
|
+
:class:`torch_geometric.data.Data` object and returns a transformed
|
31
31
|
version. The data object will be transformed before every access.
|
32
|
-
|
33
|
-
|
34
|
-
an :obj:`torch_geometric.data.Data` object and returns a
|
32
|
+
pre_transform: A function/transform that takes in a
|
33
|
+
:class:`torch_geometric.data.Data` object and returns a
|
35
34
|
transformed version. The data object will be transformed before
|
36
|
-
being saved to disk.
|
37
|
-
pre_filter
|
38
|
-
:
|
35
|
+
being saved to disk.
|
36
|
+
pre_filter: A function that takes in a
|
37
|
+
:class:`torch_geometric.data.Data` object and returns a boolean
|
39
38
|
value, indicating whether the data object should be included in the
|
40
|
-
final dataset.
|
41
|
-
force_reload
|
42
|
-
(default: :obj:`False`)
|
39
|
+
final dataset.
|
40
|
+
force_reload: Whether to re-process the dataset.
|
43
41
|
|
44
42
|
**STATS:**
|
45
43
|
|