pyg-nightly 2.6.0.dev20240318__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +13 -7
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +317 -65
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +3 -5
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +329 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +56 -22
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
@@ -99,7 +99,7 @@ class HGBDataset(InMemoryDataset):
|
|
99
99
|
# node_types = {0: 'paper', 1, 'author', ...}
|
100
100
|
# edge_types = {0: ('paper', 'cite', 'paper'), ...}
|
101
101
|
if self.name in ['acm', 'dblp', 'imdb']:
|
102
|
-
with open(self.raw_paths[0]
|
102
|
+
with open(self.raw_paths[0]) as f: # `info.dat`
|
103
103
|
info = json.load(f)
|
104
104
|
n_types = info['node.dat']['node type']
|
105
105
|
n_types = {int(k): v for k, v in n_types.items()}
|
@@ -112,7 +112,7 @@ class HGBDataset(InMemoryDataset):
|
|
112
112
|
e_types[key] = (src, rel, dst)
|
113
113
|
num_classes = len(info['label.dat']['node type']['0'])
|
114
114
|
elif self.name in ['freebase']:
|
115
|
-
with open(self.raw_paths[0]
|
115
|
+
with open(self.raw_paths[0]) as f: # `info.dat`
|
116
116
|
info = f.read().split('\n')
|
117
117
|
start = info.index('TYPE\tMEANING') + 1
|
118
118
|
end = info[start:].index('')
|
@@ -124,7 +124,7 @@ class HGBDataset(InMemoryDataset):
|
|
124
124
|
end = info[start:].index('')
|
125
125
|
for key, row in enumerate(info[start:start + end]):
|
126
126
|
row = row.split('\t')[1:]
|
127
|
-
src, dst, rel =
|
127
|
+
src, dst, rel = (v for v in row if v != '')
|
128
128
|
src, dst = n_types[int(src)], n_types[int(dst)]
|
129
129
|
rel = rel.split('-')[1]
|
130
130
|
e_types[key] = (src, rel, dst)
|
@@ -134,8 +134,8 @@ class HGBDataset(InMemoryDataset):
|
|
134
134
|
# Extract node information:
|
135
135
|
mapping_dict = {} # Maps global node indices to local ones.
|
136
136
|
x_dict = defaultdict(list)
|
137
|
-
num_nodes_dict: Dict[str, int] = defaultdict(
|
138
|
-
with open(self.raw_paths[1]
|
137
|
+
num_nodes_dict: Dict[str, int] = defaultdict(int)
|
138
|
+
with open(self.raw_paths[1]) as f: # `node.dat`
|
139
139
|
xs = [v.split('\t') for v in f.read().split('\n')[:-1]]
|
140
140
|
for x in xs:
|
141
141
|
n_id, n_type = int(x[0]), n_types[int(x[2])]
|
@@ -151,7 +151,7 @@ class HGBDataset(InMemoryDataset):
|
|
151
151
|
|
152
152
|
edge_index_dict = defaultdict(list)
|
153
153
|
edge_weight_dict = defaultdict(list)
|
154
|
-
with open(self.raw_paths[2]
|
154
|
+
with open(self.raw_paths[2]) as f: # `link.dat`
|
155
155
|
edges = [v.split('\t') for v in f.read().split('\n')[:-1]]
|
156
156
|
for src, dst, rel, weight in edges:
|
157
157
|
e_type = e_types[int(rel)]
|
@@ -168,9 +168,9 @@ class HGBDataset(InMemoryDataset):
|
|
168
168
|
|
169
169
|
# Node classification:
|
170
170
|
if self.name in ['acm', 'dblp', 'freebase', 'imdb']:
|
171
|
-
with open(self.raw_paths[3]
|
171
|
+
with open(self.raw_paths[3]) as f: # `label.dat`
|
172
172
|
train_ys = [v.split('\t') for v in f.read().split('\n')[:-1]]
|
173
|
-
with open(self.raw_paths[4]
|
173
|
+
with open(self.raw_paths[4]) as f: # `label.dat.test`
|
174
174
|
test_ys = [v.split('\t') for v in f.read().split('\n')[:-1]]
|
175
175
|
for y in train_ys:
|
176
176
|
n_id, n_type = mapping_dict[int(y[0])], n_types[int(y[2])]
|
torch_geometric/datasets/imdb.py
CHANGED
@@ -4,7 +4,6 @@ from itertools import product
|
|
4
4
|
from typing import Callable, List, Optional
|
5
5
|
|
6
6
|
import numpy as np
|
7
|
-
import scipy.sparse as sp
|
8
7
|
import torch
|
9
8
|
|
10
9
|
from torch_geometric.data import (
|
@@ -69,6 +68,8 @@ class IMDB(InMemoryDataset):
|
|
69
68
|
os.remove(path)
|
70
69
|
|
71
70
|
def process(self) -> None:
|
71
|
+
import scipy.sparse as sp
|
72
|
+
|
72
73
|
data = HeteroData()
|
73
74
|
|
74
75
|
node_types = ['movie', 'director', 'actor']
|
@@ -8,8 +8,9 @@ from torch_geometric.data import Data, InMemoryDataset
|
|
8
8
|
class KarateClub(InMemoryDataset):
|
9
9
|
r"""Zachary's karate club network from the `"An Information Flow Model for
|
10
10
|
Conflict and Fission in Small Groups"
|
11
|
-
<
|
12
|
-
34 nodes,
|
11
|
+
<https://www.journals.uchicago.edu/doi/abs/10.1086/jar.33.4.3629752>`_
|
12
|
+
paper, containing 34 nodes,
|
13
|
+
connected by 156 (undirected and unweighted) edges.
|
13
14
|
Every node is labeled by one of four classes obtained via modularity-based
|
14
15
|
clustering, following the `"Semi-supervised Classification with Graph
|
15
16
|
Convolutional Networks" <https://arxiv.org/abs/1609.02907>`_ paper.
|
@@ -4,7 +4,6 @@ from itertools import product
|
|
4
4
|
from typing import Callable, List, Optional
|
5
5
|
|
6
6
|
import numpy as np
|
7
|
-
import scipy.sparse as sp
|
8
7
|
import torch
|
9
8
|
|
10
9
|
from torch_geometric.data import (
|
@@ -68,6 +67,8 @@ class LastFM(InMemoryDataset):
|
|
68
67
|
os.remove(path)
|
69
68
|
|
70
69
|
def process(self) -> None:
|
70
|
+
import scipy.sparse as sp
|
71
|
+
|
71
72
|
data = HeteroData()
|
72
73
|
|
73
74
|
node_type_idx = np.load(osp.join(self.raw_dir, 'node_types.npy'))
|
@@ -5,6 +5,7 @@ import numpy as np
|
|
5
5
|
import torch
|
6
6
|
|
7
7
|
from torch_geometric.data import Data, InMemoryDataset, download_url
|
8
|
+
from torch_geometric.io import fs
|
8
9
|
from torch_geometric.utils import one_hot
|
9
10
|
|
10
11
|
|
@@ -115,9 +116,9 @@ class LINKXDataset(InMemoryDataset):
|
|
115
116
|
|
116
117
|
def _process_wiki(self) -> Data:
|
117
118
|
paths = {x.split('/')[-1]: x for x in self.raw_paths}
|
118
|
-
x =
|
119
|
-
edge_index =
|
120
|
-
y =
|
119
|
+
x = fs.torch_load(paths['wiki_features2M.pt'])
|
120
|
+
edge_index = fs.torch_load(paths['wiki_edges2M.pt']).t().contiguous()
|
121
|
+
y = fs.torch_load(paths['wiki_views2M.pt'])
|
121
122
|
|
122
123
|
return Data(x=x, edge_index=edge_index, y=y)
|
123
124
|
|
torch_geometric/datasets/lrgb.py
CHANGED
@@ -188,9 +188,8 @@ class LRGBDataset(InMemoryDataset):
|
|
188
188
|
graphs = pickle.load(f)
|
189
189
|
elif self.name.split('-')[0] == 'peptides':
|
190
190
|
# Peptides-func and Peptides-struct
|
191
|
-
|
192
|
-
|
193
|
-
graphs = torch.load(f)
|
191
|
+
graphs = fs.torch_load(
|
192
|
+
osp.join(self.raw_dir, f'{split}.pt'))
|
194
193
|
|
195
194
|
data_list = []
|
196
195
|
for graph in tqdm(graphs, desc=f'Processing {split} dataset'):
|
@@ -260,8 +259,7 @@ class LRGBDataset(InMemoryDataset):
|
|
260
259
|
|
261
260
|
def process_pcqm_contact(self) -> None:
|
262
261
|
for split in ['train', 'val', 'test']:
|
263
|
-
|
264
|
-
graphs = torch.load(f)
|
262
|
+
graphs = fs.torch_load(osp.join(self.raw_dir, f'{split}.pt'))
|
265
263
|
|
266
264
|
data_list = []
|
267
265
|
for graph in tqdm(graphs, desc=f'Processing {split} dataset'):
|
@@ -11,6 +11,7 @@ from torch_geometric.data import (
|
|
11
11
|
extract_tar,
|
12
12
|
extract_zip,
|
13
13
|
)
|
14
|
+
from torch_geometric.io import fs
|
14
15
|
|
15
16
|
|
16
17
|
class MalNetTiny(InMemoryDataset):
|
@@ -65,7 +66,7 @@ class MalNetTiny(InMemoryDataset):
|
|
65
66
|
self.load(self.processed_paths[0])
|
66
67
|
|
67
68
|
if split is not None:
|
68
|
-
split_slices =
|
69
|
+
split_slices = fs.torch_load(self.processed_paths[1])
|
69
70
|
if split == 'train':
|
70
71
|
self._indices = range(split_slices[0], split_slices[1])
|
71
72
|
elif split == 'val':
|
@@ -98,7 +99,7 @@ class MalNetTiny(InMemoryDataset):
|
|
98
99
|
split_slices = [0]
|
99
100
|
|
100
101
|
for split in ['train', 'val', 'test']:
|
101
|
-
with open(osp.join(self.raw_paths[1], f'{split}.txt')
|
102
|
+
with open(osp.join(self.raw_paths[1], f'{split}.txt')) as f:
|
102
103
|
filenames = f.read().split('\n')[:-1]
|
103
104
|
split_slices.append(split_slices[-1] + len(filenames))
|
104
105
|
|
@@ -107,7 +108,7 @@ class MalNetTiny(InMemoryDataset):
|
|
107
108
|
malware_type = filename.split('/')[0]
|
108
109
|
y = y_map.setdefault(malware_type, len(y_map))
|
109
110
|
|
110
|
-
with open(path
|
111
|
+
with open(path) as f:
|
111
112
|
edges = f.read().split('\n')[5:-1]
|
112
113
|
|
113
114
|
edge_indices = [[int(s) for s in e.split()] for e in edges]
|
@@ -1,14 +1,13 @@
|
|
1
1
|
import os
|
2
2
|
from typing import Callable, List, Optional
|
3
3
|
|
4
|
-
import torch
|
5
|
-
|
6
4
|
from torch_geometric.data import (
|
7
5
|
Data,
|
8
6
|
InMemoryDataset,
|
9
7
|
download_url,
|
10
8
|
extract_zip,
|
11
9
|
)
|
10
|
+
from torch_geometric.io import fs
|
12
11
|
|
13
12
|
|
14
13
|
class MNISTSuperpixels(InMemoryDataset):
|
@@ -85,7 +84,7 @@ class MNISTSuperpixels(InMemoryDataset):
|
|
85
84
|
os.unlink(path)
|
86
85
|
|
87
86
|
def process(self) -> None:
|
88
|
-
inputs =
|
87
|
+
inputs = fs.torch_load(self.raw_paths[0])
|
89
88
|
for i in range(len(inputs)):
|
90
89
|
data_list = [Data(**data_dict) for data_dict in inputs[i]]
|
91
90
|
|