pyg-nightly 2.6.0.dev20240318__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +13 -7
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +317 -65
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +3 -5
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +329 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +56 -22
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
@@ -0,0 +1,239 @@
|
|
1
|
+
import json
|
2
|
+
import os
|
3
|
+
import os.path as osp
|
4
|
+
from typing import Callable, Dict, List, Literal, Optional
|
5
|
+
|
6
|
+
import torch
|
7
|
+
import tqdm
|
8
|
+
from torch import Tensor
|
9
|
+
|
10
|
+
from torch_geometric.data import (
|
11
|
+
HeteroData,
|
12
|
+
InMemoryDataset,
|
13
|
+
download_url,
|
14
|
+
extract_tar,
|
15
|
+
)
|
16
|
+
|
17
|
+
|
18
|
+
class OPFDataset(InMemoryDataset):
    r"""The heterogeneous OPF data from the `"Large-scale Datasets for AC
    Optimal Power Flow with Topological Perturbations"
    <https://arxiv.org/abs/2406.07234>`_ paper.

    :class:`OPFDataset` is a large-scale dataset of solved optimal power flow
    problems, derived from the
    `pglib-opf <https://github.com/power-grid-lib/pglib-opf>`_ dataset.

    The physical topology of the grid is represented by the :obj:`"bus"` node
    type, and the connecting AC lines and transformers. Additionally,
    :obj:`"generator"`, :obj:`"load"`, and :obj:`"shunt"` nodes are connected
    to :obj:`"bus"` nodes using a dedicated edge type each, *e.g.*,
    :obj:`"generator_link"`.

    Edge direction corresponds to the properties of the line, *e.g.*,
    :obj:`b_fr` is the line charging susceptance at the :obj:`from`
    (source/sender) bus.

    Args:
        root (str): Root directory where the dataset should be saved.
        split (str, optional): If :obj:`"train"`, loads the training dataset.
            If :obj:`"val"`, loads the validation dataset.
            If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
        case_name (str, optional): The name of the original pglib-opf case.
            (default: :obj:`"pglib_opf_case14_ieee"`)
        num_groups (int, optional): The dataset is divided into 20 groups with
            each group containing 15,000 samples.
            For large networks, this amount of data can be overwhelming.
            The :obj:`num_groups` parameter controls the amount of data being
            downloaded. Allowed values are :obj:`[1, 20]`.
            (default: :obj:`20`)
        topological_perturbations (bool, optional): Whether to use the dataset
            with added topological perturbations. (default: :obj:`False`)
        transform (callable, optional): A function/transform that takes in
            a :obj:`torch_geometric.data.HeteroData` object and returns a
            transformed version. The data object will be transformed before
            every access. (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes
            in a :obj:`torch_geometric.data.HeteroData` object and returns
            a transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in a
            :obj:`torch_geometric.data.HeteroData` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)

    Raises:
        ValueError: If :obj:`num_groups` is outside of :obj:`[1, 20]`.
    """
    url = 'https://storage.googleapis.com/gridopt-dataset'

    def __init__(
        self,
        root: str,
        split: Literal['train', 'val', 'test'] = 'train',
        case_name: Literal[
            'pglib_opf_case14_ieee',
            'pglib_opf_case30_ieee',
            'pglib_opf_case57_ieee',
            'pglib_opf_case118_ieee',
            'pglib_opf_case500_goc',
            'pglib_opf_case2000_goc',
            'pglib_opf_case6470_rte',
            # Each entry needs its own trailing comma; without it, Python
            # concatenates adjacent string literals into one bogus name.
            'pglib_opf_case4661_sdet',
            'pglib_opf_case10000_goc',
            'pglib_opf_case13659_pegase',
        ] = 'pglib_opf_case14_ieee',
        num_groups: int = 20,
        topological_perturbations: bool = False,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        # Fail early with a clear message instead of hitting a download error
        # for a non-existent group (or silently building an empty dataset).
        if not 1 <= num_groups <= 20:
            raise ValueError(f"'num_groups' must be in the range [1, 20] "
                             f"(got {num_groups})")

        self.split = split
        self.case_name = case_name
        self.num_groups = num_groups
        self.topological_perturbations = topological_perturbations

        self._release = 'dataset_release_1'
        if topological_perturbations:
            self._release += '_nminusone'

        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)

        idx = self.processed_file_names.index(f'{split}.pt')
        self.load(self.processed_paths[idx])

    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, self._release, self.case_name, 'raw')

    @property
    def processed_dir(self) -> str:
        # The number of groups is part of the path so that datasets processed
        # with different `num_groups` values do not overwrite each other.
        return osp.join(self.root, self._release, self.case_name,
                        f'processed_{self.num_groups}')

    @property
    def raw_file_names(self) -> List[str]:
        return [f'{self.case_name}_{i}.tar.gz' for i in range(self.num_groups)]

    @property
    def processed_file_names(self) -> List[str]:
        return ['train.pt', 'val.pt', 'test.pt']

    def download(self) -> None:
        for name in self.raw_file_names:
            url = f'{self.url}/{self._release}/{name}'
            path = download_url(url, self.raw_dir)
            extract_tar(path, self.raw_dir)

    @staticmethod
    def _sample_index(name: str) -> int:
        r"""Parses the integer sample index from a file name of the form
        :obj:`"<prefix>_<index>.<ext>"`."""
        return int(name.split('.')[0].split('_')[1])

    def process(self) -> None:
        train_data_list = []
        val_data_list = []
        test_data_list = []

        # Split boundaries depend only on `num_groups`, so compute them once.
        # NOTE(review): this assumes sample indices parsed from file names are
        # numbered globally across all downloaded groups.
        num_samples = 15_000 * self.num_groups
        train_limit = int(num_samples * 0.9)
        val_limit = train_limit + int(num_samples * 0.05)

        for group in tqdm.tqdm(range(self.num_groups)):
            tmp_dir = osp.join(
                self.raw_dir,
                'gridopt-dataset-tmp',
                self._release,
                self.case_name,
                f'group_{group}',
            )

            # `os.listdir` returns entries in arbitrary order; sorting by the
            # sample index makes the order within each split deterministic.
            for name in sorted(os.listdir(tmp_dir), key=self._sample_index):
                with open(osp.join(tmp_dir, name)) as f:
                    obj = json.load(f)

                grid = obj['grid']
                solution = obj['solution']
                metadata = obj['metadata']

                # Graph-level properties:
                data = HeteroData()
                data.x = torch.tensor(grid['context']).view(-1)

                data.objective = torch.tensor(metadata['objective'])

                # Nodes (only some have a target):
                data['bus'].x = torch.tensor(grid['nodes']['bus'])
                data['bus'].y = torch.tensor(solution['nodes']['bus'])

                data['generator'].x = torch.tensor(grid['nodes']['generator'])
                data['generator'].y = torch.tensor(
                    solution['nodes']['generator'])

                data['load'].x = torch.tensor(grid['nodes']['load'])

                data['shunt'].x = torch.tensor(grid['nodes']['shunt'])

                # Edges (only ac lines and transformers have features):
                data['bus', 'ac_line', 'bus'].edge_index = (  #
                    extract_edge_index(obj, 'ac_line'))
                data['bus', 'ac_line', 'bus'].edge_attr = torch.tensor(
                    grid['edges']['ac_line']['features'])
                data['bus', 'ac_line', 'bus'].edge_label = torch.tensor(
                    solution['edges']['ac_line']['features'])

                data['bus', 'transformer', 'bus'].edge_index = (  #
                    extract_edge_index(obj, 'transformer'))
                data['bus', 'transformer', 'bus'].edge_attr = torch.tensor(
                    grid['edges']['transformer']['features'])
                data['bus', 'transformer', 'bus'].edge_label = torch.tensor(
                    solution['edges']['transformer']['features'])

                # Link edges are stored in both directions:
                data['generator', 'generator_link', 'bus'].edge_index = (  #
                    extract_edge_index(obj, 'generator_link'))
                data['bus', 'generator_link', 'generator'].edge_index = (  #
                    extract_edge_index_rev(obj, 'generator_link'))

                data['load', 'load_link', 'bus'].edge_index = (  #
                    extract_edge_index(obj, 'load_link'))
                data['bus', 'load_link', 'load'].edge_index = (  #
                    extract_edge_index_rev(obj, 'load_link'))

                data['shunt', 'shunt_link', 'bus'].edge_index = (  #
                    extract_edge_index(obj, 'shunt_link'))
                data['bus', 'shunt_link', 'shunt'].edge_index = (  #
                    extract_edge_index_rev(obj, 'shunt_link'))

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                i = self._sample_index(name)
                if i < train_limit:
                    train_data_list.append(data)
                elif i < val_limit:
                    val_data_list.append(data)
                else:
                    test_data_list.append(data)

        self.save(train_data_list, self.processed_paths[0])
        self.save(val_data_list, self.processed_paths[1])
        self.save(test_data_list, self.processed_paths[2])

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({len(self)}, '
                f'split={self.split}, '
                f'case_name={self.case_name}, '
                f'topological_perturbations={self.topological_perturbations})')
|
226
|
+
|
227
|
+
|
228
|
+
def extract_edge_index(obj: Dict, edge_name: str) -> Tensor:
    r"""Builds a :obj:`[2, num_edges]` edge index (senders -> receivers) for
    the edge type :obj:`edge_name` from a raw OPF JSON sample.

    The dtype is forced to :obj:`torch.long` so that edge types without any
    edges still yield an integer :obj:`[2, 0]` tensor. Without it,
    :obj:`torch.tensor([[], []])` would default to float.
    """
    return torch.tensor([
        obj['grid']['edges'][edge_name]['senders'],
        obj['grid']['edges'][edge_name]['receivers'],
    ], dtype=torch.long)
|
233
|
+
|
234
|
+
|
235
|
+
def extract_edge_index_rev(obj: Dict, edge_name: str) -> Tensor:
    r"""Builds the reversed :obj:`[2, num_edges]` edge index
    (receivers -> senders) for the edge type :obj:`edge_name` from a raw OPF
    JSON sample.

    The dtype is forced to :obj:`torch.long` so that edge types without any
    edges still yield an integer :obj:`[2, 0]` tensor.
    """
    return torch.tensor([
        obj['grid']['edges'][edge_name]['receivers'],
        obj['grid']['edges'][edge_name]['senders'],
    ], dtype=torch.long)
|
@@ -97,7 +97,7 @@ class OSE_GVCS(InMemoryDataset):
|
|
97
97
|
edges = defaultdict(list)
|
98
98
|
|
99
99
|
for path in self.raw_paths:
|
100
|
-
with open(path
|
100
|
+
with open(path) as f:
|
101
101
|
product = json.load(f)
|
102
102
|
categories.append(self.categories.index(product['category']))
|
103
103
|
for interaction in product['ecology']:
|
@@ -192,19 +192,19 @@ class PascalVOCKeypoints(InMemoryDataset):
|
|
192
192
|
|
193
193
|
child = obj.getElementsByTagName('xmin')[0].firstChild
|
194
194
|
assert child is not None
|
195
|
-
xmin
|
195
|
+
xmin = int(child.data) # type: ignore
|
196
196
|
|
197
197
|
child = obj.getElementsByTagName('xmax')[0].firstChild
|
198
198
|
assert child is not None
|
199
|
-
xmax =
|
199
|
+
xmax = int(child.data) # type: ignore
|
200
200
|
|
201
201
|
child = obj.getElementsByTagName('ymin')[0].firstChild
|
202
202
|
assert child is not None
|
203
|
-
ymin =
|
203
|
+
ymin = int(child.data) # type: ignore
|
204
204
|
|
205
205
|
child = obj.getElementsByTagName('ymax')[0].firstChild
|
206
206
|
assert child is not None
|
207
|
-
ymax =
|
207
|
+
ymax = int(child.data) # type: ignore
|
208
208
|
|
209
209
|
box = (xmin, ymin, xmax, ymax)
|
210
210
|
|
@@ -227,10 +227,12 @@ class PascalVOCKeypoints(InMemoryDataset):
|
|
227
227
|
|
228
228
|
# Add a small offset to the bounding because some keypoints lay
|
229
229
|
# outside the bounding box intervals.
|
230
|
-
box = (
|
231
|
-
|
232
|
-
|
233
|
-
|
230
|
+
box = (
|
231
|
+
min(int(pos[:, 0].min().floor()), box[0]) - 16,
|
232
|
+
min(int(pos[:, 1].min().floor()), box[1]) - 16,
|
233
|
+
max(int(pos[:, 0].max().ceil()), box[2]) + 16,
|
234
|
+
max(int(pos[:, 1].max().ceil()), box[3]) + 16,
|
235
|
+
)
|
234
236
|
|
235
237
|
# Rescale keypoints.
|
236
238
|
pos[:, 0] = (pos[:, 0] - box[0]) * 256.0 / (box[2] - box[0])
|
@@ -239,7 +241,7 @@ class PascalVOCKeypoints(InMemoryDataset):
|
|
239
241
|
path = osp.join(image_path, f'{filename}.jpg')
|
240
242
|
with open(path, 'rb') as f:
|
241
243
|
img = Image.open(f).convert('RGB').crop(box)
|
242
|
-
img = img.resize((256, 256), resample=Image.BICUBIC)
|
244
|
+
img = img.resize((256, 256), resample=Image.Resampling.BICUBIC)
|
243
245
|
|
244
246
|
img = transform(img)
|
245
247
|
|
@@ -66,7 +66,7 @@ class PascalPF(InMemoryDataset):
|
|
66
66
|
super().__init__(root, transform, pre_transform, pre_filter,
|
67
67
|
force_reload=force_reload)
|
68
68
|
self.load(self.processed_paths[0])
|
69
|
-
self.pairs =
|
69
|
+
self.pairs = fs.torch_load(self.processed_paths[1])
|
70
70
|
|
71
71
|
@property
|
72
72
|
def raw_file_names(self) -> List[str]:
|
@@ -121,7 +121,7 @@ class PCPNetDataset(InMemoryDataset):
|
|
121
121
|
|
122
122
|
def process(self) -> None:
|
123
123
|
path_file = self.raw_paths
|
124
|
-
with open(path_file[0]
|
124
|
+
with open(path_file[0]) as f:
|
125
125
|
filenames = f.read().split('\n')[:-1]
|
126
126
|
data_list = []
|
127
127
|
for filename in filenames:
|
@@ -7,7 +7,8 @@ from tqdm import tqdm
|
|
7
7
|
|
8
8
|
from torch_geometric.data import Data, OnDiskDataset, download_url, extract_zip
|
9
9
|
from torch_geometric.data.data import BaseData
|
10
|
-
from torch_geometric.
|
10
|
+
from torch_geometric.io import fs
|
11
|
+
from torch_geometric.utils import from_smiles as _from_smiles
|
11
12
|
|
12
13
|
|
13
14
|
class PCQM4Mv2(OnDiskDataset):
|
@@ -36,6 +37,10 @@ class PCQM4Mv2(OnDiskDataset):
|
|
36
37
|
(default: :obj:`None`)
|
37
38
|
backend (str): The :class:`Database` backend to use.
|
38
39
|
(default: :obj:`"sqlite"`)
|
40
|
+
from_smiles (callable, optional): A custom function that takes a SMILES
|
41
|
+
string and outputs a :obj:`~torch_geometric.data.Data` object.
|
42
|
+
If not set, defaults to :meth:`~torch_geometric.utils.from_smiles`.
|
43
|
+
(default: :obj:`None`)
|
39
44
|
"""
|
40
45
|
url = ('https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/'
|
41
46
|
'pcqm4m-v2.zip')
|
@@ -53,6 +58,7 @@ class PCQM4Mv2(OnDiskDataset):
|
|
53
58
|
split: str = 'train',
|
54
59
|
transform: Optional[Callable] = None,
|
55
60
|
backend: str = 'sqlite',
|
61
|
+
from_smiles: Optional[Callable] = None,
|
56
62
|
) -> None:
|
57
63
|
assert split in ['train', 'val', 'test', 'holdout']
|
58
64
|
|
@@ -64,9 +70,10 @@ class PCQM4Mv2(OnDiskDataset):
|
|
64
70
|
'y': float,
|
65
71
|
}
|
66
72
|
|
73
|
+
self.from_smiles = from_smiles or _from_smiles
|
67
74
|
super().__init__(root, transform, backend=backend, schema=schema)
|
68
75
|
|
69
|
-
split_idx =
|
76
|
+
split_idx = fs.torch_load(self.raw_paths[1])
|
70
77
|
self._indices = split_idx[self.split_mapping[split]].tolist()
|
71
78
|
|
72
79
|
@property
|
@@ -89,7 +96,7 @@ class PCQM4Mv2(OnDiskDataset):
|
|
89
96
|
data_list: List[Data] = []
|
90
97
|
iterator = enumerate(zip(df['smiles'], df['homolumogap']))
|
91
98
|
for i, (smiles, y) in tqdm(iterator, total=len(df)):
|
92
|
-
data = from_smiles(smiles)
|
99
|
+
data = self.from_smiles(smiles)
|
93
100
|
data.y = y
|
94
101
|
|
95
102
|
data_list.append(data)
|
torch_geometric/datasets/ppi.py
CHANGED
@@ -106,7 +106,7 @@ class PPI(InMemoryDataset):
|
|
106
106
|
|
107
107
|
for s, split in enumerate(['train', 'valid', 'test']):
|
108
108
|
path = osp.join(self.raw_dir, f'{split}_graph.json')
|
109
|
-
with open(path
|
109
|
+
with open(path) as f:
|
110
110
|
G = nx.DiGraph(json_graph.node_link_graph(json.load(f)))
|
111
111
|
|
112
112
|
x = np.load(osp.join(self.raw_dir, f'{split}_feats.npy'))
|
torch_geometric/datasets/qm9.py
CHANGED
@@ -13,6 +13,7 @@ from torch_geometric.data import (
|
|
13
13
|
download_url,
|
14
14
|
extract_zip,
|
15
15
|
)
|
16
|
+
from torch_geometric.io import fs
|
16
17
|
from torch_geometric.utils import one_hot, scatter
|
17
18
|
|
18
19
|
HAR2EV = 27.211386246
|
@@ -198,21 +199,21 @@ class QM9(InMemoryDataset):
|
|
198
199
|
|
199
200
|
def process(self) -> None:
|
200
201
|
try:
|
201
|
-
import rdkit
|
202
202
|
from rdkit import Chem, RDLogger
|
203
203
|
from rdkit.Chem.rdchem import BondType as BT
|
204
204
|
from rdkit.Chem.rdchem import HybridizationType
|
205
|
-
RDLogger.DisableLog('rdApp.*')
|
205
|
+
RDLogger.DisableLog('rdApp.*') # type: ignore
|
206
|
+
WITH_RDKIT = True
|
206
207
|
|
207
208
|
except ImportError:
|
208
|
-
|
209
|
+
WITH_RDKIT = False
|
209
210
|
|
210
|
-
if
|
211
|
+
if not WITH_RDKIT:
|
211
212
|
print(("Using a pre-processed version of the dataset. Please "
|
212
213
|
"install 'rdkit' to alternatively process the raw data."),
|
213
214
|
file=sys.stderr)
|
214
215
|
|
215
|
-
data_list =
|
216
|
+
data_list = fs.torch_load(self.raw_paths[0])
|
216
217
|
data_list = [Data(**data_dict) for data_dict in data_list]
|
217
218
|
|
218
219
|
if self.pre_filter is not None:
|
@@ -227,14 +228,14 @@ class QM9(InMemoryDataset):
|
|
227
228
|
types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
|
228
229
|
bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}
|
229
230
|
|
230
|
-
with open(self.raw_paths[1]
|
231
|
+
with open(self.raw_paths[1]) as f:
|
231
232
|
target = [[float(x) for x in line.split(',')[1:20]]
|
232
233
|
for line in f.read().split('\n')[1:-1]]
|
233
234
|
y = torch.tensor(target, dtype=torch.float)
|
234
235
|
y = torch.cat([y[:, 3:], y[:, :3]], dim=-1)
|
235
236
|
y = y * conversion.view(1, -1)
|
236
237
|
|
237
|
-
with open(self.raw_paths[2]
|
238
|
+
with open(self.raw_paths[2]) as f:
|
238
239
|
skip = [int(x.split()[0]) - 1 for x in f.read().split('\n')[9:-2]]
|
239
240
|
|
240
241
|
suppl = Chem.SDMolSupplier(self.raw_paths[0], removeHs=False,
|
torch_geometric/datasets/rcdd.py
CHANGED
@@ -85,13 +85,13 @@ class RCDD(InMemoryDataset):
|
|
85
85
|
mapping = torch.empty(len(node_df), dtype=torch.long)
|
86
86
|
for node_type in node_df['node_type'].unique():
|
87
87
|
mask = node_df['node_type'] == node_type
|
88
|
-
|
89
|
-
num_nodes =
|
90
|
-
mapping[
|
88
|
+
node_id = torch.from_numpy(node_df['node_id'][mask].values)
|
89
|
+
num_nodes = mask.sum()
|
90
|
+
mapping[node_id] = torch.arange(num_nodes)
|
91
91
|
data[node_type].num_nodes = num_nodes
|
92
92
|
x = np.vstack([
|
93
93
|
np.asarray(f.split(':'), dtype=np.float32)
|
94
|
-
for f in node_df['node_feat'][mask
|
94
|
+
for f in node_df['node_feat'][mask]
|
95
95
|
])
|
96
96
|
data[node_type].x = torch.from_numpy(x)
|
97
97
|
|
@@ -3,7 +3,6 @@ import os.path as osp
|
|
3
3
|
from typing import Callable, List, Optional
|
4
4
|
|
5
5
|
import numpy as np
|
6
|
-
import scipy.sparse as sp
|
7
6
|
import torch
|
8
7
|
|
9
8
|
from torch_geometric.data import (
|
@@ -76,6 +75,8 @@ class Reddit(InMemoryDataset):
|
|
76
75
|
os.unlink(path)
|
77
76
|
|
78
77
|
def process(self) -> None:
|
78
|
+
import scipy.sparse as sp
|
79
|
+
|
79
80
|
data = np.load(osp.join(self.raw_dir, 'reddit_data.npz'))
|
80
81
|
x = torch.from_numpy(data['feature']).to(torch.float)
|
81
82
|
y = torch.from_numpy(data['label']).to(torch.long)
|
@@ -3,7 +3,6 @@ import os.path as osp
|
|
3
3
|
from typing import Callable, List, Optional
|
4
4
|
|
5
5
|
import numpy as np
|
6
|
-
import scipy.sparse as sp
|
7
6
|
import torch
|
8
7
|
|
9
8
|
from torch_geometric.data import Data, InMemoryDataset, download_google_url
|
@@ -81,6 +80,8 @@ class Reddit2(InMemoryDataset):
|
|
81
80
|
download_google_url(self.role_id, self.raw_dir, 'role.json')
|
82
81
|
|
83
82
|
def process(self) -> None:
|
83
|
+
import scipy.sparse as sp
|
84
|
+
|
84
85
|
f = np.load(osp.join(self.raw_dir, 'adj_full.npz'))
|
85
86
|
adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape'])
|
86
87
|
adj = adj.tocoo()
|
@@ -89,17 +89,17 @@ class RelLinkPredDataset(InMemoryDataset):
|
|
89
89
|
download_url(f'{self.urls[self.name]}/{file_name}', self.raw_dir)
|
90
90
|
|
91
91
|
def process(self) -> None:
|
92
|
-
with open(osp.join(self.raw_dir, 'entities.dict')
|
92
|
+
with open(osp.join(self.raw_dir, 'entities.dict')) as f:
|
93
93
|
lines = [row.split('\t') for row in f.read().split('\n')[:-1]]
|
94
94
|
entities_dict = {key: int(value) for value, key in lines}
|
95
95
|
|
96
|
-
with open(osp.join(self.raw_dir, 'relations.dict')
|
96
|
+
with open(osp.join(self.raw_dir, 'relations.dict')) as f:
|
97
97
|
lines = [row.split('\t') for row in f.read().split('\n')[:-1]]
|
98
98
|
relations_dict = {key: int(value) for value, key in lines}
|
99
99
|
|
100
100
|
kwargs = {}
|
101
101
|
for split in ['train', 'valid', 'test']:
|
102
|
-
with open(osp.join(self.raw_dir, f'{split}.txt')
|
102
|
+
with open(osp.join(self.raw_dir, f'{split}.txt')) as f:
|
103
103
|
lines = [row.split('\t') for row in f.read().split('\n')[:-1]]
|
104
104
|
src = [entities_dict[row[0]] for row in lines]
|
105
105
|
rel = [relations_dict[row[1]] for row in lines]
|
@@ -3,6 +3,7 @@ import os.path as osp
|
|
3
3
|
from typing import Callable, List, Optional
|
4
4
|
|
5
5
|
import torch
|
6
|
+
from torch import Tensor
|
6
7
|
|
7
8
|
from torch_geometric.data import (
|
8
9
|
Data,
|
@@ -85,13 +86,14 @@ class S3DIS(InMemoryDataset):
|
|
85
86
|
def process(self) -> None:
|
86
87
|
import h5py
|
87
88
|
|
88
|
-
with open(self.raw_paths[0]
|
89
|
+
with open(self.raw_paths[0]) as f:
|
89
90
|
filenames = [x.split('/')[-1] for x in f.read().split('\n')[:-1]]
|
90
91
|
|
91
|
-
with open(self.raw_paths[1]
|
92
|
+
with open(self.raw_paths[1]) as f:
|
92
93
|
rooms = f.read().split('\n')[:-1]
|
93
94
|
|
94
|
-
xs
|
95
|
+
xs: List[Tensor] = []
|
96
|
+
ys: List[Tensor] = []
|
95
97
|
for filename in filenames:
|
96
98
|
h5 = h5py.File(osp.join(self.raw_dir, filename))
|
97
99
|
xs += torch.from_numpy(h5['data'][:]).unbind(0)
|
@@ -148,8 +148,8 @@ class ShapeNet(InMemoryDataset):
|
|
148
148
|
elif split == 'trainval':
|
149
149
|
path = self.processed_paths[3]
|
150
150
|
else:
|
151
|
-
raise ValueError(
|
152
|
-
|
151
|
+
raise ValueError(f'Split {split} found, but expected either '
|
152
|
+
'train, val, trainval or test')
|
153
153
|
|
154
154
|
self.load(path)
|
155
155
|
|
@@ -213,7 +213,7 @@ class ShapeNet(InMemoryDataset):
|
|
213
213
|
for i, split in enumerate(['train', 'val', 'test']):
|
214
214
|
path = osp.join(self.raw_dir, 'train_test_split',
|
215
215
|
f'shuffled_{split}_file_list.json')
|
216
|
-
with open(path
|
216
|
+
with open(path) as f:
|
217
217
|
filenames = [
|
218
218
|
osp.sep.join(name.split('/')[1:]) + '.txt'
|
219
219
|
for name in json.load(f)
|
@@ -6,7 +6,7 @@ from typing import Callable, List, Optional
|
|
6
6
|
import torch
|
7
7
|
|
8
8
|
from torch_geometric.data import InMemoryDataset, download_url, extract_zip
|
9
|
-
from torch_geometric.io import read_off, read_txt_array
|
9
|
+
from torch_geometric.io import fs, read_off, read_txt_array
|
10
10
|
|
11
11
|
|
12
12
|
class SHREC2016(InMemoryDataset):
|
@@ -79,7 +79,7 @@ class SHREC2016(InMemoryDataset):
|
|
79
79
|
self.cat = category.lower()
|
80
80
|
super().__init__(root, transform, pre_transform, pre_filter,
|
81
81
|
force_reload=force_reload)
|
82
|
-
self.__ref__ =
|
82
|
+
self.__ref__ = fs.torch_load(self.processed_paths[0])
|
83
83
|
path = self.processed_paths[1] if train else self.processed_paths[2]
|
84
84
|
self.load(path)
|
85
85
|
|
@@ -22,6 +22,9 @@ class EgoData(Data):
|
|
22
22
|
|
23
23
|
def read_ego(files: List[str], name: str) -> List[EgoData]:
|
24
24
|
import pandas as pd
|
25
|
+
import tqdm
|
26
|
+
|
27
|
+
files = sorted(files)
|
25
28
|
|
26
29
|
all_featnames = []
|
27
30
|
files = [
|
@@ -38,7 +41,7 @@ def read_ego(files: List[str], name: str) -> List[EgoData]:
|
|
38
41
|
all_featnames_dict = {key: i for i, key in enumerate(all_featnames)}
|
39
42
|
|
40
43
|
data_list = []
|
41
|
-
for i in range(0, len(files), 5):
|
44
|
+
for i in tqdm.tqdm(range(0, len(files), 5)):
|
42
45
|
circles_file = files[i]
|
43
46
|
edges_file = files[i + 1]
|
44
47
|
egofeat_file = files[i + 2]
|
@@ -65,6 +68,9 @@ def read_ego(files: List[str], name: str) -> List[EgoData]:
|
|
65
68
|
x_all[:, torch.tensor(indices)] = x
|
66
69
|
x = x_all
|
67
70
|
|
71
|
+
if x.size(1) > 100_000:
|
72
|
+
x = x.to_sparse_csr()
|
73
|
+
|
68
74
|
idx = pd.read_csv(feat_file, sep=' ', header=None, dtype=str,
|
69
75
|
usecols=[0]).squeeze()
|
70
76
|
|