pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic; see the package's registry page for more details.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
|
@@ -1,14 +1,13 @@
|
|
|
1
1
|
import os
|
|
2
2
|
from typing import Callable, List, Optional
|
|
3
3
|
|
|
4
|
-
import torch
|
|
5
|
-
|
|
6
4
|
from torch_geometric.data import (
|
|
7
5
|
Data,
|
|
8
6
|
InMemoryDataset,
|
|
9
7
|
download_url,
|
|
10
8
|
extract_zip,
|
|
11
9
|
)
|
|
10
|
+
from torch_geometric.io import fs
|
|
12
11
|
|
|
13
12
|
|
|
14
13
|
class MNISTSuperpixels(InMemoryDataset):
|
|
@@ -85,7 +84,7 @@ class MNISTSuperpixels(InMemoryDataset):
|
|
|
85
84
|
os.unlink(path)
|
|
86
85
|
|
|
87
86
|
def process(self) -> None:
|
|
88
|
-
inputs =
|
|
87
|
+
inputs = fs.torch_load(self.raw_paths[0])
|
|
89
88
|
for i in range(len(inputs)):
|
|
90
89
|
data_list = [Data(**data_dict) for data_dict in inputs[i]]
|
|
91
90
|
|
|
@@ -79,7 +79,7 @@ class ModelNet(InMemoryDataset):
|
|
|
79
79
|
|
|
80
80
|
urls = {
|
|
81
81
|
'10':
|
|
82
|
-
'http://
|
|
82
|
+
'http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip', # noqa
|
|
83
83
|
'40': 'http://modelnet.cs.princeton.edu/ModelNet40.zip'
|
|
84
84
|
}
|
|
85
85
|
|
|
@@ -0,0 +1,492 @@
|
|
|
1
|
+
import gzip
|
|
2
|
+
import json
|
|
3
|
+
import multiprocessing
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
from collections import defaultdict
|
|
7
|
+
from multiprocessing import Pool
|
|
8
|
+
from typing import Callable, List, Optional, Tuple
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import requests
|
|
12
|
+
import torch
|
|
13
|
+
from tqdm import tqdm
|
|
14
|
+
|
|
15
|
+
from torch_geometric.data import Data, InMemoryDataset, download_url
|
|
16
|
+
from torch_geometric.io import fs
|
|
17
|
+
from torch_geometric.llm.models import LLM
|
|
18
|
+
from torch_geometric.utils import one_hot
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def clean_up_description(description: str) -> str:
    r"""Returns the normalized first sentence of a PubChem description.

    A trailing space is appended to :obj:`description`, a fixed list of
    known typos and irregular phrasings is rewritten, and the text is cut
    after the first ``". "`` that is directly followed by an upper-case ASCII
    letter. A few names that contain such a pattern themselves (e.g.,
    ``"C.I. "``) are skipped before the search starts.

    Args:
        description (str): The raw description text.

    Returns:
        The first sentence including its terminating period. If no sentence
        boundary is found, the rewritten description (without the appended
        space) is returned. If a skipped prefix covers the whole text, only
        its first character is returned.
    """
    text = f"{description} "

    # Drop a leading adjective "Pure" (note: every occurrence is removed).
    if text.startswith("Pure "):
        text = text.replace("Pure ", "")
    # Fix a known typo.
    if text.startswith("Mercurycombines"):
        text = text.replace("Mercurycombines", "Mercury combines")

    # Special cases, applied one after the other in exactly this order.
    special_cases = (
        ("17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione. ",
         "17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione is "),
        ("5-Thymidylic acid. ", "5-Thymidylic acid. is "),
        ("5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. ",
         "5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. is "),
        ("Guanosine 5'-(trihydrogen diphosphate), monoanhydride"
         " with phosphorothioic acid. ",
         "Guanosine 5'-(trihydrogen diphosphate), monoanhydride"
         " with phosphorothioic acid is "),
        ("5'-Uridylic acid. ", "5'-Uridylic acid is "),
        ("5'-Adenylic acid, ", "5'-Adenylic acid is "),
        ("Uridine 5'-(tetrahydrogen triphosphate). ",
         "Uridine 5'-(tetrahydrogen triphosphate). is "),
        ("Inosine 5'-Monophosphate. ", "Inosine 5'-Monophosphate. is "),
        ("Pivaloyloxymethyl butyrate (AN-9), ",
         "Pivaloyloxymethyl butyrate (AN-9) is "),
        ("4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine. ",
         "4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine"
         " is "),
        ("Cardamonin (also known as Dihydroxymethoxychalcone), ",
         "Cardamonin (also known as Dihydroxymethoxychalcone) is "),
        ("Lithium has been used to treat ", "Lithium is "),
        ("4,4'-Methylenebis ", "4,4'-Methylenebis is "),
        ("2,3,7,8-Tetrachlorodibenzo-p-dioxin",
         "2,3,7,8-Tetrachlorodibenzo-p-dioxin is "),
        ("Exposure to 2,4,5-trichlorophenol ",
         "2,4,5-Trichlorophenol exposure "),
    )
    for old, new in special_cases:
        text = text.replace(old, new)

    # Names that themselves contain a ". <Capital>" pattern are skipped
    # before searching for the end of the first sentence.
    skipped_prefixes = (
        'C.I. ',
        'Nectriapyrone. D ',
        'Salmonella enterica sv. Minnesota LPS core oligosaccharide',
    )
    start = next(
        (len(prefix) for prefix in skipped_prefixes
         if text.startswith(prefix)),
        0,
    )

    # `last` is the final candidate position; it is used when no boundary
    # is found, which drops the appended trailing space.
    last = len(text) - 2
    if start > last:
        cut = 0
    else:
        cut = next(
            (pos for pos in range(start, last)
             if text[pos:pos + 2] == '. ' and 'A' <= text[pos + 2] <= 'Z'),
            last,
        )
    return text[:cut + 1]
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def extract_name(
    name_raw: str,
    description: str,
) -> Tuple[Optional[str], str, str]:
    r"""Extracts the molecule name from a PubChem description and anonymizes
    the description.

    The first sentence of :obj:`description` (as returned by
    :func:`clean_up_description`) is cut at the first copula-like phrase
    (e.g., ``" is "``, ``" are "``, ``" belongs to "``), and the text before
    it is taken as the molecule name. If no such phrase exists,
    :obj:`name_raw` is used only if the first sentence starts with it.

    Args:
        name_raw (str): The compound name reported by PubChem.
        description (str): The raw description text.

    Returns:
        A tuple :obj:`(extracted_name, extracted_description,
        first_sentence)`:

        - :obj:`extracted_name` is :obj:`None` if no (non-empty) name could
          be determined.
        - :obj:`extracted_description` is :obj:`description` with every
          occurrence of the extracted name replaced by
          :obj:`"This molecule"` (or :obj:`"These molecules"` if the first
          sentence contains ``" are "`` or ``" were "``).
        - :obj:`first_sentence` is the first sentence in which all
          copula-like phrases have been replaced by a splitter marker.

    Raises:
        TypeError: If :obj:`name_raw` is not a string (e.g., :obj:`None`)
            and the first sentence contains no copula-like phrase.
            NOTE(review): the caller in this module wraps this function in
            ``except Exception: continue``, so such records are skipped.
    """
    first_sentence = clean_up_description(description)

    if ' are ' in first_sentence or ' were ' in first_sentence:
        replaced_words = 'These molecules'
    else:
        replaced_words = 'This molecule'

    # Replace every copula-like phrase by a common marker and cut at the
    # first marker. Replacements are applied in this fixed order.
    splitter = ' -- -- '
    for phrase in (
            ' is ',
            ' are ',
            ' was ',
            ' were ',
            ' appears ',
            ' occurs ',
            ' stands for ',
            ' belongs to ',
            ' exists ',  # only for CID=11443
            ' has been used in trials ',
            ' has been investigated ',
            ' has many uses ',
    ):
        first_sentence = first_sentence.replace(phrase, splitter)

    if splitter in first_sentence:
        extracted_name = first_sentence.split(splitter, 1)[0]
    elif first_sentence.startswith(name_raw):
        extracted_name = name_raw
    else:
        # Otherwise (including when `name_raw` only occurs later in the
        # sentence) no name is extracted.
        extracted_name = None

    if extracted_name:
        extracted_description = description.replace(extracted_name,
                                                     replaced_words)
    else:
        # An empty name would make `str.replace` insert `replaced_words`
        # between every single character, so it is treated as "no name".
        extracted_name = None
        extracted_description = description

    return extracted_name, extracted_description, first_sentence
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class MoleculeGPTDataset(InMemoryDataset):
    r"""The dataset from the `"MoleculeGPT: Instruction Following Large
    Language Models for Molecular Property Prediction"
    <https://ai4d3.github.io/2023/papers/34.pdf>`_ paper.

    :meth:`download` fetches compound descriptions from the PubChem
    PUG-View annotations API as well as the PubChem SDF archives.
    :meth:`process` keeps the molecules that have a description, converts
    them into graphs, and uses an LLM to generate an instruction (a question)
    whose answer is the description.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
        total_page_num (int, optional): The number of pages from PubChem.
            (default: :obj:`10`)
        total_block_num (int, optional): The blocks of SDF files from PubChem.
            (default: :obj:`1`)
        num_units (int, optional): Number of units of the sample.
            (default: :obj:`-1`, which means all units will be used)
    """
    description_url = (
        'https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/annotations/'
        'heading/json?heading_type=Compound&heading=Record+Description&page={}'
    )
    compound_url = ('https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/'
                    'CURRENT-Full/SDF')

    # Width of the ID range encoded in each PubChem SDF archive name
    # (``Compound_<first>_<last>.sdf.gz``).
    _sdf_block_size = 500000

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
        total_page_num: int = 10,
        total_block_num: int = 1,
        num_units: int = -1,
    ):
        self.total_page_num = total_page_num
        self.total_block_num = total_block_num
        self.num_units = num_units

        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return ['pubchem.csv']

    @property
    def processed_file_names(self) -> List[str]:
        return ['data.pt']

    @classmethod
    def _sdf_file_name(cls, block_id: int) -> str:
        r"""Returns the PubChem SDF archive file name of block
        :obj:`block_id`.
        """
        l_id = block_id * cls._sdf_block_size + 1
        r_id = (block_id + 1) * cls._sdf_block_size
        return f"Compound_{l_id:09d}_{r_id:09d}.sdf.gz"

    def download(self) -> None:
        # Step 01. Extract description
        step1_folder = f"{self.raw_dir}/step_01_PubChemSTM_description"
        if not os.path.exists(step1_folder):
            os.makedirs(step1_folder)
            self._download_descriptions(step1_folder)

        # Step 02. Download SDF Files
        step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF"
        if not os.path.exists(step2_folder):
            for block_id in tqdm(range(self.total_block_num)):
                download_url(
                    f"{self.compound_url}/{self._sdf_file_name(block_id)}",
                    step2_folder)

    def _download_descriptions(self, step1_folder: str) -> None:
        r"""Downloads :obj:`total_page_num` pages of PubChem compound
        descriptions, writes one text dump per page into
        :obj:`step1_folder`, and stores the CID-keyed name/text mappings as
        JSON files in :obj:`raw_dir`.
        """
        valid_cid_set = set()
        cid2name_raw, cid2name_extracted = defaultdict(list), defaultdict(list)
        cid2text_raw, cid2text_extracted = defaultdict(list), defaultdict(list)

        for page_index in tqdm(range(self.total_page_num)):
            page_num = page_index + 1  # Pages are requested starting at 1.
            out_path = f"{step1_folder}/Compound_description_{page_num}.txt"

            description_data = requests.get(
                self.description_url.format(page_num)).json()
            description_data = description_data["Annotations"]
            if description_data["Page"] != page_num:
                raise RuntimeError(
                    f"Expected PubChem description page {page_num}, but "
                    f"received page {description_data['Page']}")

            # Descriptions may contain non-ASCII characters.
            with open(out_path, "w", encoding="utf-8") as f_out:
                for record in description_data["Annotation"]:
                    # Best effort: a record with a missing/malformed field,
                    # or for which `extract_name` fails, is skipped.
                    try:
                        cid = record["LinkedRecords"]["CID"][0]
                        if "Name" in record:
                            name_raw = record["Name"]
                            cid2name_raw[cid].append(name_raw)
                        else:
                            name_raw = None

                        for data in record["Data"]:
                            description = (data["Value"]["StringWithMarkup"]
                                           [0]["String"].strip())

                            extracted = extract_name(name_raw, description)
                            extracted_name, extracted_description = (
                                extracted[0], extracted[1])
                            if extracted_name is not None:
                                cid2name_extracted[cid].append(extracted_name)

                            cid2text_raw[cid].append(description)
                            cid2text_extracted[cid].append(
                                extracted_description)

                            valid_cid_set.add(cid)
                            f_out.write(f"{cid}\n")
                            f_out.write(f"{extracted_description}\n\n")
                    except Exception:
                        continue

        print(f"Total CID (with raw name) {len(cid2name_raw)}")
        print(f"Total CID (with extracted name) {len(cid2name_extracted)}")
        print(f"Total CID {len(valid_cid_set)}")

        for file_name, mapping in (
            ("CID2name_raw.json", cid2name_raw),
            ("CID2name.json", cid2name_extracted),
            ("CID2text_raw.json", cid2text_raw),
            ("CID2text.json", cid2text_extracted),
        ):
            with open(f"{self.raw_dir}/{file_name}", "w") as f:
                json.dump(mapping, f)

    def process(self, use_mp: bool = False) -> None:
        r"""Builds the processed dataset.

        Args:
            use_mp (bool, optional): If set to :obj:`True`, filters the SDF
                blocks in parallel with a process pool of at most 8 workers.
                (default: :obj:`False`)
        """
        try:
            from rdkit import Chem  # noqa: F401
            from rdkit.Chem.rdchem import BondType  # noqa: F401
        except ImportError:
            self._process_preprocessed()
            return

        # Step 03. Filter out SDF
        step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF"
        step3_folder = f"{self.raw_dir}/step_03_PubChemSTM_filtered"
        if not os.path.exists(step3_folder):
            os.makedirs(step3_folder)
            self._filter_sdf_blocks(step2_folder, step3_folder, use_mp)

        # Step 04. Merge SDF
        self._merge_sdf_blocks(step3_folder)

        # Step 05. Convert to PyG data format
        data_list = self._build_data_list()
        self.save(data_list, self.processed_paths[0])

    def _process_preprocessed(self) -> None:
        r"""Fallback used when :obj:`rdkit` is not installed: loads a list of
        data dictionaries from :obj:`raw_paths[0]` instead of processing the
        raw SDF files.
        """
        print(("Using a pre-processed version of the dataset. Please "
               "install 'rdkit' to alternatively process the raw data."),
              file=sys.stderr)

        # NOTE(review): `download()` never writes `raw_paths[0]`
        # ('pubchem.csv'), so this path only works if that file is placed in
        # `raw_dir` by other means - TODO confirm where it comes from.
        data_list = fs.torch_load(self.raw_paths[0])
        data_list = [Data(**data_dict) for data_dict in data_list]

        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        self.save(data_list, self.processed_paths[0])

    def _filter_sdf_blocks(self, step2_folder: str, step3_folder: str,
                           use_mp: bool) -> None:
        r"""For every SDF block, writes the molecules whose CID appears in
        :obj:`CID2text.json` to :obj:`{step3_folder}/filtered_{block_id}.sdf`.
        """
        with open(f"{self.raw_dir}/CID2text.json") as f:
            target_cids = set(json.load(f).keys())

        jobs = [(
            f"{step2_folder}/{self._sdf_file_name(block_id)}",
            f"{step3_folder}/filtered_{block_id}.sdf",
            target_cids,
            block_id,
        ) for block_id in range(self.total_block_num)]

        if use_mp:
            num_cpus = multiprocessing.cpu_count()
            print(f"{num_cpus} CPUs")
            # The worker is a static method and is therefore picklable by
            # its qualified name, unlike a function local to `process()`.
            with Pool(min(8, num_cpus)) as pool:
                pool.starmap(self._extract_one_sdf_file, jobs)
        else:
            for job in jobs:
                self._extract_one_sdf_file(*job)

    @staticmethod
    def _extract_one_sdf_file(gz_path: str, out_path: str, target_cids: set,
                              block_id: int) -> None:
        r"""Copies the molecules of the gzipped SDF file :obj:`gz_path`
        whose :obj:`PUBCHEM_COMPOUND_CID` property is in :obj:`target_cids`
        to the SDF file :obj:`out_path`.
        """
        from rdkit import Chem

        valid_mol_count = 0
        writer = Chem.SDWriter(out_path)
        try:
            with gzip.open(gz_path) as gzip_loader:
                for mol in tqdm(Chem.ForwardSDMolSupplier(gzip_loader)):
                    if mol is None:  # RDKit yields `None` for bad records.
                        continue
                    if mol.GetProp("PUBCHEM_COMPOUND_CID") not in target_cids:
                        continue
                    writer.write(mol)
                    valid_mol_count += 1
        finally:
            writer.close()

        print(f"block id: {block_id}\nfound {valid_mol_count}\n\n")
        sys.stdout.flush()

    def _merge_sdf_blocks(self, step3_folder: str) -> None:
        r"""Concatenates all filtered SDF blocks into
        :obj:`{raw_dir}/molecules.sdf`.
        """
        from rdkit import Chem

        with open(f"{self.raw_dir}/CID2text.json") as f:
            target_cids = set(json.load(f).keys())
        print(f'The length of target_CID_list: {len(target_cids)}')

        found_cids = set()
        writer = Chem.SDWriter(f'{self.raw_dir}/molecules.sdf')
        try:
            # Step 03 only writes blocks `0, ..., total_block_num - 1`.
            for block_id in range(self.total_block_num):
                compound_file_path = f"{step3_folder}/filtered_{block_id}.sdf"
                try:
                    suppl = Chem.SDMolSupplier(compound_file_path)
                except OSError:  # e.g., the filtered block file is missing.
                    print(f"block id: {block_id} with 0 valid SDF file")
                    continue

                for mol in tqdm(suppl):
                    if mol is None:
                        continue
                    writer.write(mol)
                    found_cids.add(mol.GetProp("PUBCHEM_COMPOUND_CID"))
        finally:
            writer.close()
        print(f"In total: {len(found_cids)} molecules")

    def _build_data_list(self) -> List[Data]:
        r"""Converts :obj:`{raw_dir}/molecules.sdf` into a list of
        :class:`~torch_geometric.data.Data` objects with one-hot atom
        features, one-hot bond features, the RDKit SMILES string, an
        LLM-generated :obj:`instruction`, and the description as :obj:`y`.
        """
        from rdkit import Chem
        from rdkit.Chem.rdchem import BondType as BT

        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5}
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

        # Real data
        with open(f'{self.raw_dir}/CID2text.json') as f:
            cid2text = json.load(f)

        suppl = Chem.SDMolSupplier(f'{self.raw_dir}/molecules.sdf')

        llm = LLM(
            # model_name='lmsys/vicuna-7b-v1.5',
            model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
            num_params=1,
            dtype=torch.bfloat16,
        )
        prompt = ("Propose a question regarding the molecule '∼' "
                  "whose answer is: {}:")

        data_list: List[Data] = []
        for mol in tqdm(suppl):
            if mol is None or not mol.HasProp('PUBCHEM_COMPOUND_CID'):
                continue
            cid = mol.GetProp("PUBCHEM_COMPOUND_CID")

            m = Chem.MolFromSmiles(mol.GetProp("PUBCHEM_SMILES"))
            if m is None:
                continue

            bond_list = list(m.GetBonds())
            # Skip molecules with bond types outside of `bonds` instead of
            # failing with a `KeyError` (and before the costly LLM call).
            if any(bond.GetBondType() not in bonds for bond in bond_list):
                continue

            # `dtype=torch.long` keeps `one_hot` valid even without atoms.
            x = torch.tensor(
                [types.get(atom.GetSymbol(), 5) for atom in m.GetAtoms()],
                dtype=torch.long)
            x = one_hot(x, num_classes=len(types), dtype=torch.float)

            rows, cols, edge_types = [], [], []
            for bond in bond_list:
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                edge_types += [bonds[bond.GetBondType()]] * 2
                rows += [i, j]
                cols += [j, i]

            edge_index = torch.tensor([rows, cols], dtype=torch.long)
            edge_type = torch.tensor(edge_types, dtype=torch.long)
            edge_attr = one_hot(edge_type, num_classes=len(bonds))

            ground_truth = cid2text[cid][0]
            instruction = llm.inference([prompt.format(ground_truth)])[0]

            data = Data(
                x=x,
                edge_index=edge_index,
                edge_attr=edge_attr,
                smiles=Chem.MolToSmiles(m),
                instruction=instruction,
                y=ground_truth,
            )

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

            if self.num_units > 0 and len(data_list) >= self.num_units:
                break

        return data_list
|
|
@@ -210,8 +210,9 @@ class MoleculeNet(InMemoryDataset):
|
|
|
210
210
|
data.y = y
|
|
211
211
|
|
|
212
212
|
if data.num_nodes == 0:
|
|
213
|
-
warnings.warn(
|
|
214
|
-
|
|
213
|
+
warnings.warn(
|
|
214
|
+
f"Skipping molecule '{smiles}' since it "
|
|
215
|
+
f"resulted in zero atoms", stacklevel=2)
|
|
215
216
|
continue
|
|
216
217
|
|
|
217
218
|
if self.pre_filter is not None and not self.pre_filter(data):
|
|
@@ -2,8 +2,6 @@ import os
|
|
|
2
2
|
import os.path as osp
|
|
3
3
|
from typing import Callable, List, Optional
|
|
4
4
|
|
|
5
|
-
import torch
|
|
6
|
-
|
|
7
5
|
from torch_geometric.data import (
|
|
8
6
|
Data,
|
|
9
7
|
InMemoryDataset,
|
|
@@ -110,7 +108,7 @@ class NeuroGraphDataset(InMemoryDataset):
|
|
|
110
108
|
fs.rm(osp.join(self.raw_dir, self.name))
|
|
111
109
|
|
|
112
110
|
def process(self) -> None:
|
|
113
|
-
data, slices =
|
|
111
|
+
data, slices = fs.torch_load(self.raw_paths[0])
|
|
114
112
|
|
|
115
113
|
num_samples = slices['x'].size(0) - 1
|
|
116
114
|
data_list: List[Data] = []
|
|
@@ -147,7 +147,7 @@ class OGB_MAG(InMemoryDataset):
|
|
|
147
147
|
for node_type in ['author', 'institution', 'field_of_study']:
|
|
148
148
|
data[node_type].num_nodes = num_nodes_df[node_type].tolist()[0]
|
|
149
149
|
else:
|
|
150
|
-
emb_dict =
|
|
150
|
+
emb_dict = fs.torch_load(self.raw_paths[-1])
|
|
151
151
|
for key, value in emb_dict.items():
|
|
152
152
|
if key != 'paper':
|
|
153
153
|
data[key].x = value
|
torch_geometric/datasets/opf.py
CHANGED
|
@@ -41,6 +41,12 @@ class OPFDataset(InMemoryDataset):
|
|
|
41
41
|
If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
|
|
42
42
|
case_name (str, optional): The name of the original pglib-opf case.
|
|
43
43
|
(default: :obj:`"pglib_opf_case14_ieee"`)
|
|
44
|
+
num_groups (int, optional): The dataset is divided into 20 groups with
|
|
45
|
+
each group containing 15,000 samples.
|
|
46
|
+
For large networks, this amount of data can be overwhelming.
|
|
47
|
+
The :obj:`num_groups` parameters controls the amount of data being
|
|
48
|
+
downloaded. Allowed values are :obj:`[1, 20]`.
|
|
49
|
+
(default: :obj:`20`)
|
|
44
50
|
topological_perturbations (bool, optional): Whether to use the dataset
|
|
45
51
|
with added topological perturbations. (default: :obj:`False`)
|
|
46
52
|
transform (callable, optional): A function/transform that takes in
|
|
@@ -76,6 +82,7 @@ class OPFDataset(InMemoryDataset):
|
|
|
76
82
|
'pglib_opf_case10000_goc',
|
|
77
83
|
'pglib_opf_case13659_pegase',
|
|
78
84
|
] = 'pglib_opf_case14_ieee',
|
|
85
|
+
num_groups: int = 20,
|
|
79
86
|
topological_perturbations: bool = False,
|
|
80
87
|
transform: Optional[Callable] = None,
|
|
81
88
|
pre_transform: Optional[Callable] = None,
|
|
@@ -85,6 +92,7 @@ class OPFDataset(InMemoryDataset):
|
|
|
85
92
|
|
|
86
93
|
self.split = split
|
|
87
94
|
self.case_name = case_name
|
|
95
|
+
self.num_groups = num_groups
|
|
88
96
|
self.topological_perturbations = topological_perturbations
|
|
89
97
|
|
|
90
98
|
self._release = 'dataset_release_1'
|
|
@@ -103,11 +111,12 @@ class OPFDataset(InMemoryDataset):
|
|
|
103
111
|
|
|
104
112
|
@property
|
|
105
113
|
def processed_dir(self) -> str:
|
|
106
|
-
return osp.join(self.root, self._release, self.case_name,
|
|
114
|
+
return osp.join(self.root, self._release, self.case_name,
|
|
115
|
+
f'processed_{self.num_groups}')
|
|
107
116
|
|
|
108
117
|
@property
|
|
109
118
|
def raw_file_names(self) -> List[str]:
|
|
110
|
-
return [f'{self.case_name}_{i}.tar.gz' for i in range(
|
|
119
|
+
return [f'{self.case_name}_{i}.tar.gz' for i in range(self.num_groups)]
|
|
111
120
|
|
|
112
121
|
@property
|
|
113
122
|
def processed_file_names(self) -> List[str]:
|
|
@@ -124,7 +133,7 @@ class OPFDataset(InMemoryDataset):
|
|
|
124
133
|
val_data_list = []
|
|
125
134
|
test_data_list = []
|
|
126
135
|
|
|
127
|
-
for group in tqdm.tqdm(range(
|
|
136
|
+
for group in tqdm.tqdm(range(self.num_groups)):
|
|
128
137
|
tmp_dir = osp.join(
|
|
129
138
|
self.raw_dir,
|
|
130
139
|
'gridopt-dataset-tmp',
|
|
@@ -139,11 +148,14 @@ class OPFDataset(InMemoryDataset):
|
|
|
139
148
|
|
|
140
149
|
grid = obj['grid']
|
|
141
150
|
solution = obj['solution']
|
|
151
|
+
metadata = obj['metadata']
|
|
142
152
|
|
|
143
153
|
# Graph-level properties:
|
|
144
154
|
data = HeteroData()
|
|
145
155
|
data.x = torch.tensor(grid['context']).view(-1)
|
|
146
156
|
|
|
157
|
+
data.objective = torch.tensor(metadata['objective'])
|
|
158
|
+
|
|
147
159
|
# Nodes (only some have a target):
|
|
148
160
|
data['bus'].x = torch.tensor(grid['nodes']['bus'])
|
|
149
161
|
data['bus'].y = torch.tensor(solution['nodes']['bus'])
|
|
@@ -193,9 +205,11 @@ class OPFDataset(InMemoryDataset):
|
|
|
193
205
|
data = self.pre_transform(data)
|
|
194
206
|
|
|
195
207
|
i = int(name.split('.')[0].split('_')[1])
|
|
196
|
-
|
|
208
|
+
train_limit = int(15_000 * self.num_groups * 0.9)
|
|
209
|
+
val_limit = train_limit + int(15_000 * self.num_groups * 0.05)
|
|
210
|
+
if i < train_limit:
|
|
197
211
|
train_data_list.append(data)
|
|
198
|
-
elif i <
|
|
212
|
+
elif i < val_limit:
|
|
199
213
|
val_data_list.append(data)
|
|
200
214
|
else:
|
|
201
215
|
test_data_list.append(data)
|
|
@@ -66,7 +66,7 @@ class PascalPF(InMemoryDataset):
|
|
|
66
66
|
super().__init__(root, transform, pre_transform, pre_filter,
|
|
67
67
|
force_reload=force_reload)
|
|
68
68
|
self.load(self.processed_paths[0])
|
|
69
|
-
self.pairs =
|
|
69
|
+
self.pairs = fs.torch_load(self.processed_paths[1])
|
|
70
70
|
|
|
71
71
|
@property
|
|
72
72
|
def raw_file_names(self) -> List[str]:
|
|
@@ -7,6 +7,7 @@ from tqdm import tqdm
|
|
|
7
7
|
|
|
8
8
|
from torch_geometric.data import Data, OnDiskDataset, download_url, extract_zip
|
|
9
9
|
from torch_geometric.data.data import BaseData
|
|
10
|
+
from torch_geometric.io import fs
|
|
10
11
|
from torch_geometric.utils import from_smiles as _from_smiles
|
|
11
12
|
|
|
12
13
|
|
|
@@ -72,7 +73,7 @@ class PCQM4Mv2(OnDiskDataset):
|
|
|
72
73
|
self.from_smiles = from_smiles or _from_smiles
|
|
73
74
|
super().__init__(root, transform, backend=backend, schema=schema)
|
|
74
75
|
|
|
75
|
-
split_idx =
|
|
76
|
+
split_idx = fs.torch_load(self.raw_paths[1])
|
|
76
77
|
self._indices = split_idx[self.split_mapping[split]].tolist()
|
|
77
78
|
|
|
78
79
|
@property
|
torch_geometric/datasets/ppi.py
CHANGED
|
@@ -107,7 +107,8 @@ class PPI(InMemoryDataset):
|
|
|
107
107
|
for s, split in enumerate(['train', 'valid', 'test']):
|
|
108
108
|
path = osp.join(self.raw_dir, f'{split}_graph.json')
|
|
109
109
|
with open(path) as f:
|
|
110
|
-
G = nx.DiGraph(
|
|
110
|
+
G = nx.DiGraph(
|
|
111
|
+
json_graph.node_link_graph(json.load(f), edges="links"))
|
|
111
112
|
|
|
112
113
|
x = np.load(osp.join(self.raw_dir, f'{split}_feats.npy'))
|
|
113
114
|
x = torch.from_numpy(x).to(torch.float)
|