pyg-nightly 2.6.0.dev20240318__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +13 -7
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +317 -65
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +3 -5
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +329 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +56 -22
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
@@ -30,15 +30,14 @@ class BAShapes(InMemoryDataset):
|
|
30
30
|
:class:`torch_geometric.datasets.graph_generator.BAGraph` instead.
|
31
31
|
|
32
32
|
Args:
|
33
|
-
connection_distribution
|
34
|
-
|
33
|
+
connection_distribution: Specifies how the houses and the BA graph get
|
34
|
+
connected. Valid inputs are :obj:`"random"`
|
35
35
|
(random BA graph nodes are selected for connection to the houses),
|
36
36
|
and :obj:`"uniform"` (uniformly distributed BA graph nodes are
|
37
|
-
selected for connection to the houses).
|
38
|
-
transform
|
39
|
-
:
|
37
|
+
selected for connection to the houses).
|
38
|
+
transform: A function/transform that takes in a
|
39
|
+
:class:`torch_geometric.data.Data` object and returns a transformed
|
40
40
|
version. The data object will be transformed before every access.
|
41
|
-
(default: :obj:`None`)
|
42
41
|
"""
|
43
42
|
def __init__(
|
44
43
|
self,
|
@@ -94,7 +94,7 @@ class BrcaTcga(InMemoryDataset):
|
|
94
94
|
graph_feat = torch.from_numpy(graph_feat).to(torch.float)
|
95
95
|
graph_labels = np.loadtxt(self.raw_paths[1], delimiter=',')
|
96
96
|
graph_label = torch.from_numpy(graph_labels).to(torch.float)
|
97
|
-
edge_index =
|
97
|
+
edge_index = fs.torch_load(self.raw_paths[2])
|
98
98
|
|
99
99
|
data_list = []
|
100
100
|
for x, y in zip(graph_feat, graph_label):
|
@@ -0,0 +1,145 @@
|
|
1
|
+
import os.path as osp
|
2
|
+
from typing import Callable, List, Optional
|
3
|
+
|
4
|
+
import torch
|
5
|
+
|
6
|
+
from torch_geometric.data import InMemoryDataset, download_url
|
7
|
+
from torch_geometric.data.hypergraph_data import HyperGraphData
|
8
|
+
|
9
|
+
|
10
|
+
class CornellTemporalHyperGraphDataset(InMemoryDataset):
|
11
|
+
r"""A collection of temporal higher-order network datasets from the
|
12
|
+
`"Simplicial Closure and higher-order link prediction"
|
13
|
+
<https://arxiv.org/abs/1802.06916>`_ paper.
|
14
|
+
Each of the datasets is a timestamped sequence of simplices, where a
|
15
|
+
simplex is a set of :math:`k` nodes.
|
16
|
+
|
17
|
+
See the original `datasets page
|
18
|
+
<https://www.cs.cornell.edu/~arb/data/>`_ for more details about
|
19
|
+
individual datasets.
|
20
|
+
|
21
|
+
Args:
|
22
|
+
root (str): Root directory where the dataset should be saved.
|
23
|
+
name (str): The name of the dataset.
|
24
|
+
split (str, optional): If :obj:`"train"`, loads the training dataset.
|
25
|
+
If :obj:`"val"`, loads the validation dataset.
|
26
|
+
If :obj:`"test"`, loads the test dataset.
|
27
|
+
(default: :obj:`"train"`)
|
28
|
+
setting (str, optional): If :obj:`"transductive"`, loads the dataset
|
29
|
+
for transductive training.
|
30
|
+
If :obj:`"inductive"`, loads the dataset for inductive training.
|
31
|
+
(default: :obj:`"transductive"`)
|
32
|
+
transform (callable, optional): A function/transform that takes in an
|
33
|
+
:obj:`torch_geometric.data.Data` object and returns a transformed
|
34
|
+
version. The data object will be transformed before every access.
|
35
|
+
(default: :obj:`None`)
|
36
|
+
pre_transform (callable, optional): A function/transform that takes in
|
37
|
+
an :obj:`torch_geometric.data.Data` object and returns a
|
38
|
+
transformed version. The data object will be transformed before
|
39
|
+
being saved to disk. (default: :obj:`None`)
|
40
|
+
pre_filter (callable, optional): A function that takes in an
|
41
|
+
:obj:`torch_geometric.data.Data` object and returns a boolean
|
42
|
+
value, indicating whether the data object should be included in the
|
43
|
+
final dataset. (default: :obj:`None`)
|
44
|
+
force_reload (bool, optional): Whether to re-process the dataset.
|
45
|
+
(default: :obj:`False`)
|
46
|
+
"""
|
47
|
+
names = [
|
48
|
+
'email-Eu',
|
49
|
+
'email-Enron',
|
50
|
+
'NDC-classes',
|
51
|
+
'tags-math-sx',
|
52
|
+
'email-Eu-25',
|
53
|
+
'NDC-substances',
|
54
|
+
'congress-bills',
|
55
|
+
'tags-ask-ubuntu',
|
56
|
+
'email-Enron-25',
|
57
|
+
'NDC-classes-25',
|
58
|
+
'threads-ask-ubuntu',
|
59
|
+
'contact-high-school',
|
60
|
+
'NDC-substances-25',
|
61
|
+
'congress-bills-25',
|
62
|
+
'contact-primary-school',
|
63
|
+
]
|
64
|
+
url = ('https://huggingface.co/datasets/SauravMaheshkar/{}/raw/main/'
|
65
|
+
'processed/{}/{}')
|
66
|
+
|
67
|
+
def __init__(
|
68
|
+
self,
|
69
|
+
root: str,
|
70
|
+
name: str,
|
71
|
+
split: str = 'train',
|
72
|
+
setting: str = 'transductive',
|
73
|
+
transform: Optional[Callable] = None,
|
74
|
+
pre_transform: Optional[Callable] = None,
|
75
|
+
pre_filter: Optional[Callable] = None,
|
76
|
+
force_reload: bool = False,
|
77
|
+
) -> None:
|
78
|
+
assert name in self.names
|
79
|
+
assert setting in ['transductive', 'inductive']
|
80
|
+
|
81
|
+
self.name = name
|
82
|
+
self.setting = setting
|
83
|
+
|
84
|
+
super().__init__(root, transform, pre_transform, pre_filter,
|
85
|
+
force_reload)
|
86
|
+
|
87
|
+
if split == 'train':
|
88
|
+
path = self.processed_paths[0]
|
89
|
+
elif split == 'val':
|
90
|
+
path = self.processed_paths[1]
|
91
|
+
elif split == 'test':
|
92
|
+
path = self.processed_paths[2]
|
93
|
+
else:
|
94
|
+
raise ValueError(f"Split '{split}' not found")
|
95
|
+
|
96
|
+
self.load(path)
|
97
|
+
|
98
|
+
@property
|
99
|
+
def raw_dir(self) -> str:
|
100
|
+
return osp.join(self.root, self.name, self.setting, 'raw')
|
101
|
+
|
102
|
+
@property
|
103
|
+
def raw_file_names(self) -> List[str]:
|
104
|
+
return ['train_df.csv', 'val_df.csv', 'test_df.csv']
|
105
|
+
|
106
|
+
@property
|
107
|
+
def processed_dir(self) -> str:
|
108
|
+
return osp.join(self.root, self.name, self.setting, 'processed')
|
109
|
+
|
110
|
+
@property
|
111
|
+
def processed_file_names(self) -> List[str]:
|
112
|
+
return ['train_data.pt', 'val_data.pt', 'test_data.pt']
|
113
|
+
|
114
|
+
def download(self) -> None:
|
115
|
+
for filename in self.raw_file_names:
|
116
|
+
url = self.url.format(self.name, self.setting, filename)
|
117
|
+
download_url(url, self.raw_dir)
|
118
|
+
|
119
|
+
def process(self) -> None:
|
120
|
+
import pandas as pd
|
121
|
+
|
122
|
+
for raw_path, path in zip(self.raw_paths, self.processed_paths):
|
123
|
+
df = pd.read_csv(raw_path)
|
124
|
+
|
125
|
+
data_list = []
|
126
|
+
for i, row in df.iterrows():
|
127
|
+
edge_indices: List[List[int]] = [[], []]
|
128
|
+
for node in eval(row['nodes']): # str(list) -> list:
|
129
|
+
edge_indices[0].append(node)
|
130
|
+
edge_indices[1].append(i) # Use `i` as hyper-edge index.
|
131
|
+
|
132
|
+
x = torch.tensor([[row['timestamp']]], dtype=torch.float)
|
133
|
+
edge_index = torch.tensor(edge_indices)
|
134
|
+
|
135
|
+
data = HyperGraphData(x=x, edge_index=edge_index)
|
136
|
+
|
137
|
+
if self.pre_filter is not None and not self.pre_filter(data):
|
138
|
+
continue
|
139
|
+
|
140
|
+
if self.pre_transform is not None:
|
141
|
+
data = self.pre_transform(data)
|
142
|
+
|
143
|
+
data_list.append(data)
|
144
|
+
|
145
|
+
self.save(data_list, path)
|
torch_geometric/datasets/dblp.py
CHANGED
@@ -4,7 +4,6 @@ from itertools import product
|
|
4
4
|
from typing import Callable, List, Optional
|
5
5
|
|
6
6
|
import numpy as np
|
7
|
-
import scipy.sparse as sp
|
8
7
|
import torch
|
9
8
|
|
10
9
|
from torch_geometric.data import (
|
@@ -110,6 +109,8 @@ class DBLP(InMemoryDataset):
|
|
110
109
|
os.remove(path)
|
111
110
|
|
112
111
|
def process(self) -> None:
|
112
|
+
import scipy.sparse as sp
|
113
|
+
|
113
114
|
data = HeteroData()
|
114
115
|
|
115
116
|
node_types = ['author', 'paper', 'term', 'conference']
|
@@ -72,7 +72,7 @@ class DBP15K(InMemoryDataset):
|
|
72
72
|
|
73
73
|
def process(self) -> None:
|
74
74
|
embs = {}
|
75
|
-
with open(osp.join(self.raw_dir, 'sub.glove.300d')
|
75
|
+
with open(osp.join(self.raw_dir, 'sub.glove.300d')) as f:
|
76
76
|
for i, line in enumerate(f):
|
77
77
|
info = line.strip().split(' ')
|
78
78
|
if len(info) > 300:
|
@@ -112,7 +112,7 @@ class DBP15K(InMemoryDataset):
|
|
112
112
|
subj, rel, obj = g1.t()
|
113
113
|
|
114
114
|
x_dict = {}
|
115
|
-
with open(feature_path
|
115
|
+
with open(feature_path) as f:
|
116
116
|
for line in f:
|
117
117
|
info = line.strip().split('\t')
|
118
118
|
info = info if len(info) == 2 else info + ['**UNK**']
|
torch_geometric/datasets/fake.py
CHANGED
@@ -170,7 +170,7 @@ class FakeHeteroDataset(InMemoryDataset):
|
|
170
170
|
random.shuffle(edge_types)
|
171
171
|
|
172
172
|
self.edge_types: List[Tuple[str, str, str]] = []
|
173
|
-
count: Dict[Tuple[str, str], int] = defaultdict(
|
173
|
+
count: Dict[Tuple[str, str], int] = defaultdict(int)
|
174
174
|
for edge_type in edge_types[:max(num_edge_types, 1)]:
|
175
175
|
rel = f'e{count[edge_type]}'
|
176
176
|
count[edge_type] += 1
|
@@ -222,8 +222,6 @@ class FakeHeteroDataset(InMemoryDataset):
|
|
222
222
|
elif self.edge_dim == 1:
|
223
223
|
store.edge_weight = torch.rand(store.num_edges)
|
224
224
|
|
225
|
-
pass
|
226
|
-
|
227
225
|
if self._num_classes > 0 and self.task == 'graph':
|
228
226
|
data.y = torch.tensor([random.randint(0, self._num_classes - 1)])
|
229
227
|
|
@@ -3,7 +3,6 @@ import os.path as osp
|
|
3
3
|
from typing import Callable, List, Optional
|
4
4
|
|
5
5
|
import numpy as np
|
6
|
-
import scipy.sparse as sp
|
7
6
|
import torch
|
8
7
|
|
9
8
|
from torch_geometric.data import Data, InMemoryDataset, download_google_url
|
@@ -73,6 +72,8 @@ class Flickr(InMemoryDataset):
|
|
73
72
|
download_google_url(self.role_id, self.raw_dir, 'role.json')
|
74
73
|
|
75
74
|
def process(self) -> None:
|
75
|
+
import scipy.sparse as sp
|
76
|
+
|
76
77
|
f = np.load(osp.join(self.raw_dir, 'adj_full.npz'))
|
77
78
|
adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape'])
|
78
79
|
adj = adj.tocoo()
|
@@ -75,7 +75,7 @@ class FB15k_237(InMemoryDataset):
|
|
75
75
|
rel_dict: Dict[str, int] = {}
|
76
76
|
|
77
77
|
for path in self.raw_paths:
|
78
|
-
with open(path
|
78
|
+
with open(path) as f:
|
79
79
|
lines = [x.split('\t') for x in f.read().split('\n')[:-1]]
|
80
80
|
|
81
81
|
edge_index = torch.empty((2, len(lines)), dtype=torch.long)
|
@@ -9,6 +9,7 @@ from torch_geometric.data import (
|
|
9
9
|
download_url,
|
10
10
|
extract_zip,
|
11
11
|
)
|
12
|
+
from torch_geometric.io import fs
|
12
13
|
|
13
14
|
|
14
15
|
class GDELTLite(InMemoryDataset):
|
@@ -80,9 +81,9 @@ class GDELTLite(InMemoryDataset):
|
|
80
81
|
def process(self) -> None:
|
81
82
|
import pandas as pd
|
82
83
|
|
83
|
-
x =
|
84
|
+
x = fs.torch_load(self.raw_paths[0])
|
84
85
|
df = pd.read_csv(self.raw_paths[1])
|
85
|
-
edge_attr =
|
86
|
+
edge_attr = fs.torch_load(self.raw_paths[2])
|
86
87
|
|
87
88
|
row = torch.from_numpy(df['src'].values)
|
88
89
|
col = torch.from_numpy(df['dst'].values)
|
@@ -13,6 +13,7 @@ from torch_geometric.data import (
|
|
13
13
|
extract_tar,
|
14
14
|
extract_zip,
|
15
15
|
)
|
16
|
+
from torch_geometric.io import fs
|
16
17
|
from torch_geometric.utils import one_hot, to_undirected
|
17
18
|
|
18
19
|
|
@@ -145,9 +146,9 @@ class GEDDataset(InMemoryDataset):
|
|
145
146
|
path = self.processed_paths[0] if train else self.processed_paths[1]
|
146
147
|
self.load(path)
|
147
148
|
path = osp.join(self.processed_dir, f'{self.name}_ged.pt')
|
148
|
-
self.ged =
|
149
|
+
self.ged = fs.torch_load(path)
|
149
150
|
path = osp.join(self.processed_dir, f'{self.name}_norm_ged.pt')
|
150
|
-
self.norm_ged =
|
151
|
+
self.norm_ged = fs.torch_load(path)
|
151
152
|
|
152
153
|
@property
|
153
154
|
def raw_file_names(self) -> List[str]:
|
@@ -0,0 +1,263 @@
|
|
1
|
+
import sys
|
2
|
+
from typing import Any, Callable, Dict, List, Optional
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
import torch
|
6
|
+
from tqdm import tqdm
|
7
|
+
|
8
|
+
from torch_geometric.data import (
|
9
|
+
Data,
|
10
|
+
InMemoryDataset,
|
11
|
+
download_google_url,
|
12
|
+
extract_zip,
|
13
|
+
)
|
14
|
+
from torch_geometric.io import fs
|
15
|
+
|
16
|
+
|
17
|
+
def safe_index(lst: List[Any], e: int) -> int:
|
18
|
+
return lst.index(e) if e in lst else len(lst) - 1
|
19
|
+
|
20
|
+
|
21
|
+
class GitMolDataset(InMemoryDataset):
|
22
|
+
r"""The dataset from the `"GIT-Mol: A Multi-modal Large Language Model
|
23
|
+
for Molecular Science with Graph, Image, and Text"
|
24
|
+
<https://arxiv.org/pdf/2308.06911>`_ paper.
|
25
|
+
|
26
|
+
Args:
|
27
|
+
root (str): Root directory where the dataset should be saved.
|
28
|
+
transform (callable, optional): A function/transform that takes in an
|
29
|
+
:obj:`torch_geometric.data.Data` object and returns a transformed
|
30
|
+
version. The data object will be transformed before every access.
|
31
|
+
(default: :obj:`None`)
|
32
|
+
pre_transform (callable, optional): A function/transform that takes in
|
33
|
+
an :obj:`torch_geometric.data.Data` object and returns a
|
34
|
+
transformed version. The data object will be transformed before
|
35
|
+
being saved to disk. (default: :obj:`None`)
|
36
|
+
pre_filter (callable, optional): A function that takes in an
|
37
|
+
:obj:`torch_geometric.data.Data` object and returns a boolean
|
38
|
+
value, indicating whether the data object should be included in the
|
39
|
+
final dataset. (default: :obj:`None`)
|
40
|
+
force_reload (bool, optional): Whether to re-process the dataset.
|
41
|
+
(default: :obj:`False`)
|
42
|
+
split (int, optional): Datasets split, train/valid/test=0/1/2.
|
43
|
+
(default: :obj:`0`)
|
44
|
+
"""
|
45
|
+
|
46
|
+
raw_url_id = '1loBXabD6ncAFY-vanRsVtRUSFkEtBweg'
|
47
|
+
|
48
|
+
def __init__(
|
49
|
+
self,
|
50
|
+
root: str,
|
51
|
+
transform: Optional[Callable] = None,
|
52
|
+
pre_transform: Optional[Callable] = None,
|
53
|
+
pre_filter: Optional[Callable] = None,
|
54
|
+
force_reload: bool = False,
|
55
|
+
split: int = 0,
|
56
|
+
):
|
57
|
+
from torchvision import transforms
|
58
|
+
|
59
|
+
self.split = split
|
60
|
+
|
61
|
+
if self.split == 0:
|
62
|
+
self.img_transform = transforms.Compose([
|
63
|
+
transforms.Resize((224, 224)),
|
64
|
+
transforms.RandomRotation(15),
|
65
|
+
transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
|
66
|
+
transforms.ToTensor(),
|
67
|
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
68
|
+
std=[0.229, 0.224, 0.225])
|
69
|
+
])
|
70
|
+
else:
|
71
|
+
self.img_transform = transforms.Compose([
|
72
|
+
transforms.Resize((224, 224)),
|
73
|
+
transforms.ToTensor(),
|
74
|
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
75
|
+
std=[0.229, 0.224, 0.225])
|
76
|
+
])
|
77
|
+
|
78
|
+
super().__init__(root, transform, pre_transform, pre_filter,
|
79
|
+
force_reload=force_reload)
|
80
|
+
|
81
|
+
self.load(self.processed_paths[0])
|
82
|
+
|
83
|
+
@property
|
84
|
+
def raw_file_names(self) -> List[str]:
|
85
|
+
return ['train_3500.pkl', 'valid_450.pkl', 'test_450.pkl']
|
86
|
+
|
87
|
+
@property
|
88
|
+
def processed_file_names(self) -> str:
|
89
|
+
return ['train.pt', 'valid.pt', 'test.pt'][self.split]
|
90
|
+
|
91
|
+
def download(self) -> None:
|
92
|
+
file_path = download_google_url(
|
93
|
+
self.raw_url_id,
|
94
|
+
self.raw_dir,
|
95
|
+
'gitmol.zip',
|
96
|
+
)
|
97
|
+
extract_zip(file_path, self.raw_dir)
|
98
|
+
|
99
|
+
def process(self) -> None:
|
100
|
+
import pandas as pd
|
101
|
+
from PIL import Image
|
102
|
+
|
103
|
+
try:
|
104
|
+
from rdkit import Chem, RDLogger
|
105
|
+
RDLogger.DisableLog('rdApp.*') # type: ignore
|
106
|
+
WITH_RDKIT = True
|
107
|
+
|
108
|
+
except ImportError:
|
109
|
+
WITH_RDKIT = False
|
110
|
+
|
111
|
+
if not WITH_RDKIT:
|
112
|
+
print(("Using a pre-processed version of the dataset. Please "
|
113
|
+
"install 'rdkit' to alternatively process the raw data."),
|
114
|
+
file=sys.stderr)
|
115
|
+
|
116
|
+
data_list = fs.torch_load(self.raw_paths[0])
|
117
|
+
data_list = [Data(**data_dict) for data_dict in data_list]
|
118
|
+
|
119
|
+
if self.pre_filter is not None:
|
120
|
+
data_list = [d for d in data_list if self.pre_filter(d)]
|
121
|
+
|
122
|
+
if self.pre_transform is not None:
|
123
|
+
data_list = [self.pre_transform(d) for d in data_list]
|
124
|
+
|
125
|
+
self.save(data_list, self.processed_paths[0])
|
126
|
+
return
|
127
|
+
|
128
|
+
allowable_features: Dict[str, List[Any]] = {
|
129
|
+
'possible_atomic_num_list':
|
130
|
+
list(range(1, 119)) + ['misc'],
|
131
|
+
'possible_formal_charge_list':
|
132
|
+
[-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'],
|
133
|
+
'possible_chirality_list': [
|
134
|
+
Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
|
135
|
+
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
|
136
|
+
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
|
137
|
+
Chem.rdchem.ChiralType.CHI_OTHER
|
138
|
+
],
|
139
|
+
'possible_hybridization_list': [
|
140
|
+
Chem.rdchem.HybridizationType.SP,
|
141
|
+
Chem.rdchem.HybridizationType.SP2,
|
142
|
+
Chem.rdchem.HybridizationType.SP3,
|
143
|
+
Chem.rdchem.HybridizationType.SP3D,
|
144
|
+
Chem.rdchem.HybridizationType.SP3D2,
|
145
|
+
Chem.rdchem.HybridizationType.UNSPECIFIED, 'misc'
|
146
|
+
],
|
147
|
+
'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
|
148
|
+
'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6],
|
149
|
+
'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
|
150
|
+
'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'],
|
151
|
+
'possible_is_aromatic_list': [False, True],
|
152
|
+
'possible_is_in_ring_list': [False, True],
|
153
|
+
'possible_bond_type_list': [
|
154
|
+
Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
|
155
|
+
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC,
|
156
|
+
Chem.rdchem.BondType.ZERO
|
157
|
+
],
|
158
|
+
'possible_bond_dirs': [ # only for double bond stereo information
|
159
|
+
Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT,
|
160
|
+
Chem.rdchem.BondDir.ENDDOWNRIGHT
|
161
|
+
],
|
162
|
+
'possible_bond_stereo_list': [
|
163
|
+
Chem.rdchem.BondStereo.STEREONONE,
|
164
|
+
Chem.rdchem.BondStereo.STEREOZ,
|
165
|
+
Chem.rdchem.BondStereo.STEREOE,
|
166
|
+
Chem.rdchem.BondStereo.STEREOCIS,
|
167
|
+
Chem.rdchem.BondStereo.STEREOTRANS,
|
168
|
+
Chem.rdchem.BondStereo.STEREOANY,
|
169
|
+
],
|
170
|
+
'possible_is_conjugated_list': [False, True]
|
171
|
+
}
|
172
|
+
|
173
|
+
data = pd.read_pickle(
|
174
|
+
f'{self.raw_dir}/igcdata_toy/{self.raw_file_names[self.split]}')
|
175
|
+
|
176
|
+
data_list = []
|
177
|
+
for _, r in tqdm(data.iterrows(), total=data.shape[0]):
|
178
|
+
smiles = r['isosmiles']
|
179
|
+
mol = Chem.MolFromSmiles(smiles.strip('\n'))
|
180
|
+
if mol is not None:
|
181
|
+
# text
|
182
|
+
summary = r['summary']
|
183
|
+
# image
|
184
|
+
cid = r['cid']
|
185
|
+
img_file = f'{self.raw_dir}/igcdata_toy/imgs/CID_{cid}.png'
|
186
|
+
img = Image.open(img_file).convert('RGB')
|
187
|
+
img = self.img_transform(img).unsqueeze(0)
|
188
|
+
# graph
|
189
|
+
atom_features_list = []
|
190
|
+
for atom in mol.GetAtoms():
|
191
|
+
atom_feature = [
|
192
|
+
safe_index(
|
193
|
+
allowable_features['possible_atomic_num_list'],
|
194
|
+
atom.GetAtomicNum()),
|
195
|
+
allowable_features['possible_chirality_list'].index(
|
196
|
+
atom.GetChiralTag()),
|
197
|
+
safe_index(allowable_features['possible_degree_list'],
|
198
|
+
atom.GetTotalDegree()),
|
199
|
+
safe_index(
|
200
|
+
allowable_features['possible_formal_charge_list'],
|
201
|
+
atom.GetFormalCharge()),
|
202
|
+
safe_index(allowable_features['possible_numH_list'],
|
203
|
+
atom.GetTotalNumHs()),
|
204
|
+
safe_index(
|
205
|
+
allowable_features[
|
206
|
+
'possible_number_radical_e_list'],
|
207
|
+
atom.GetNumRadicalElectrons()),
|
208
|
+
safe_index(
|
209
|
+
allowable_features['possible_hybridization_list'],
|
210
|
+
atom.GetHybridization()),
|
211
|
+
allowable_features['possible_is_aromatic_list'].index(
|
212
|
+
atom.GetIsAromatic()),
|
213
|
+
allowable_features['possible_is_in_ring_list'].index(
|
214
|
+
atom.IsInRing()),
|
215
|
+
]
|
216
|
+
atom_features_list.append(atom_feature)
|
217
|
+
x = torch.tensor(np.array(atom_features_list),
|
218
|
+
dtype=torch.long)
|
219
|
+
|
220
|
+
edges_list = []
|
221
|
+
edge_features_list = []
|
222
|
+
for bond in mol.GetBonds():
|
223
|
+
i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
|
224
|
+
edge_feature = [
|
225
|
+
safe_index(
|
226
|
+
allowable_features['possible_bond_type_list'],
|
227
|
+
bond.GetBondType()),
|
228
|
+
allowable_features['possible_bond_stereo_list'].index(
|
229
|
+
bond.GetStereo()),
|
230
|
+
allowable_features['possible_is_conjugated_list'].
|
231
|
+
index(bond.GetIsConjugated()),
|
232
|
+
]
|
233
|
+
edges_list.append((i, j))
|
234
|
+
edge_features_list.append(edge_feature)
|
235
|
+
edges_list.append((j, i))
|
236
|
+
edge_features_list.append(edge_feature)
|
237
|
+
|
238
|
+
edge_index = torch.tensor(
|
239
|
+
np.array(edges_list).T,
|
240
|
+
dtype=torch.long,
|
241
|
+
)
|
242
|
+
edge_attr = torch.tensor(
|
243
|
+
np.array(edge_features_list),
|
244
|
+
dtype=torch.long,
|
245
|
+
)
|
246
|
+
|
247
|
+
data = Data(
|
248
|
+
x=x,
|
249
|
+
edge_index=edge_index,
|
250
|
+
smiles=smiles,
|
251
|
+
edge_attr=edge_attr,
|
252
|
+
image=img,
|
253
|
+
caption=summary,
|
254
|
+
)
|
255
|
+
|
256
|
+
if self.pre_filter is not None and not self.pre_filter(data):
|
257
|
+
continue
|
258
|
+
if self.pre_transform is not None:
|
259
|
+
data = self.pre_transform(data)
|
260
|
+
|
261
|
+
data_list.append(data)
|
262
|
+
|
263
|
+
self.save(data_list, self.processed_paths[0])
|
@@ -12,6 +12,7 @@ from torch_geometric.data import (
|
|
12
12
|
download_url,
|
13
13
|
extract_zip,
|
14
14
|
)
|
15
|
+
from torch_geometric.io import fs
|
15
16
|
from torch_geometric.utils import remove_self_loops
|
16
17
|
|
17
18
|
|
@@ -61,31 +62,31 @@ class GNNBenchmarkDataset(InMemoryDataset):
|
|
61
62
|
- #features
|
62
63
|
- #classes
|
63
64
|
* - PATTERN
|
64
|
-
-
|
65
|
+
- 14,000
|
65
66
|
- ~118.9
|
66
67
|
- ~6,098.9
|
67
68
|
- 3
|
68
69
|
- 2
|
69
70
|
* - CLUSTER
|
70
|
-
-
|
71
|
+
- 12,000
|
71
72
|
- ~117.2
|
72
73
|
- ~4,303.9
|
73
74
|
- 7
|
74
75
|
- 6
|
75
76
|
* - MNIST
|
76
|
-
-
|
77
|
+
- 70,000
|
77
78
|
- ~70.6
|
78
79
|
- ~564.5
|
79
80
|
- 3
|
80
81
|
- 10
|
81
82
|
* - CIFAR10
|
82
|
-
-
|
83
|
+
- 60,000
|
83
84
|
- ~117.6
|
84
85
|
- ~941.2
|
85
86
|
- 5
|
86
87
|
- 10
|
87
88
|
* - TSP
|
88
|
-
-
|
89
|
+
- 12,000
|
89
90
|
- ~275.4
|
90
91
|
- ~6,885.0
|
91
92
|
- 2
|
@@ -126,9 +127,9 @@ class GNNBenchmarkDataset(InMemoryDataset):
|
|
126
127
|
if self.name == 'CSL' and split != 'train':
|
127
128
|
split = 'train'
|
128
129
|
logging.warning(
|
129
|
-
|
130
|
-
|
131
|
-
|
130
|
+
"Dataset 'CSL' does not provide a standardized splitting. "
|
131
|
+
"Instead, it is recommended to perform 5-fold cross "
|
132
|
+
"validation with stratifed sampling")
|
132
133
|
|
133
134
|
super().__init__(root, transform, pre_transform, pre_filter,
|
134
135
|
force_reload=force_reload)
|
@@ -181,7 +182,7 @@ class GNNBenchmarkDataset(InMemoryDataset):
|
|
181
182
|
data_list = self.process_CSL()
|
182
183
|
self.save(data_list, self.processed_paths[0])
|
183
184
|
else:
|
184
|
-
inputs =
|
185
|
+
inputs = fs.torch_load(self.raw_paths[0])
|
185
186
|
for i in range(len(inputs)):
|
186
187
|
data_list = [Data(**data_dict) for data_dict in inputs[i]]
|
187
188
|
|
@@ -197,7 +198,7 @@ class GNNBenchmarkDataset(InMemoryDataset):
|
|
197
198
|
with open(self.raw_paths[0], 'rb') as f:
|
198
199
|
adjs = pickle.load(f)
|
199
200
|
|
200
|
-
ys =
|
201
|
+
ys = fs.torch_load(self.raw_paths[1]).tolist()
|
201
202
|
|
202
203
|
data_list = []
|
203
204
|
for adj, y in zip(adjs, ys):
|