pyg-nightly 2.6.0.dev20240318__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +13 -7
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +317 -65
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +3 -5
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +329 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +56 -22
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
@@ -0,0 +1,485 @@
|
|
1
|
+
import gzip
|
2
|
+
import json
|
3
|
+
import multiprocessing
|
4
|
+
import os
|
5
|
+
import sys
|
6
|
+
from collections import defaultdict
|
7
|
+
from multiprocessing import Pool
|
8
|
+
from typing import Callable, List, Optional, Tuple
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
import requests
|
12
|
+
import torch
|
13
|
+
from tqdm import tqdm
|
14
|
+
|
15
|
+
from torch_geometric.data import Data, InMemoryDataset, download_url
|
16
|
+
from torch_geometric.io import fs
|
17
|
+
from torch_geometric.nn.nlp import LLM
|
18
|
+
from torch_geometric.utils import one_hot
|
19
|
+
|
20
|
+
|
21
|
+
def clean_up_description(description: str) -> str:
    r"""Normalizes a PubChem record description and returns its first
    sentence.

    A fixed list of hand-curated string rewrites is applied first: a leading
    ``"Pure "`` adjective is dropped, a known typo is fixed, and several
    compound-specific lead-ins are rewritten into a ``"<name> is ..."`` form.
    The text is then cut after the first ``"."`` that is followed by a space
    and an ASCII upper-case letter. A few known prefixes whose abbreviations
    contain that pattern are skipped before searching.

    Args:
        description (str): The raw description string.

    Returns:
        str: The first sentence, including its terminating period. If no
        sentence boundary is found, the whole normalized description is
        returned.
    """
    # Pad with a space so a period at the very end is also followed by " ".
    description = description + " "

    # Extra adjective "Pure" (note: every occurrence is removed):
    if description.startswith("Pure "):
        description = description.replace("Pure ", "")
    # Fix a known typo:
    if description.startswith("Mercurycombines"):
        description = description.replace("Mercurycombines",
                                          "Mercury combines")

    # Hand-curated special cases, applied in this exact order:
    special_cases = (
        ("17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione. ",
         "17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione is "),
        ("5-Thymidylic acid. ", "5-Thymidylic acid. is "),
        ("5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. ",
         "5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. is "),
        (("Guanosine 5'-(trihydrogen diphosphate), monoanhydride"
          " with phosphorothioic acid. "),
         ("Guanosine 5'-(trihydrogen diphosphate), monoanhydride"
          " with phosphorothioic acid is ")),
        ("5'-Uridylic acid. ", "5'-Uridylic acid is "),
        ("5'-Adenylic acid, ", "5'-Adenylic acid is "),
        ("Uridine 5'-(tetrahydrogen triphosphate). ",
         "Uridine 5'-(tetrahydrogen triphosphate). is "),
        ("Inosine 5'-Monophosphate. ", "Inosine 5'-Monophosphate. is "),
        ("Pivaloyloxymethyl butyrate (AN-9), ",
         "Pivaloyloxymethyl butyrate (AN-9) is "),
        (("4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- "
          "pyrrolo(2,3-d)pyrimidine. "),
         ("4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- "
          "pyrrolo(2,3-d)pyrimidine is ")),
        ("Cardamonin (also known as Dihydroxymethoxychalcone), ",
         "Cardamonin (also known as Dihydroxymethoxychalcone) is "),
        ("Lithium has been used to treat ", "Lithium is "),
        ("4,4'-Methylenebis ", "4,4'-Methylenebis is "),
        ("2,3,7,8-Tetrachlorodibenzo-p-dioxin",
         "2,3,7,8-Tetrachlorodibenzo-p-dioxin is "),
        ("Exposure to 2,4,5-trichlorophenol ",
         "2,4,5-Trichlorophenol exposure "),
    )
    for old, new in special_cases:
        description = description.replace(old, new)

    # Skip prefixes whose abbreviations would otherwise look like the end of
    # a sentence (first matching prefix wins):
    start_index = 0
    for prefix in (
            'C.I. ',
            'Nectriapyrone. D ',
            'Salmonella enterica sv. Minnesota LPS core oligosaccharide',
    ):
        if description.startswith(prefix):
            start_index = len(prefix)
            break

    # A sentence ends at a "." followed by a space and an upper-case letter:
    pos = description.find('. ', start_index)
    while pos != -1:
        if pos + 2 < len(description) and 'A' <= description[pos + 2] <= 'Z':
            return description[:pos + 1]
        pos = description.find('. ', pos + 1)

    # No sentence boundary: return everything but the padding space. This
    # also covers descriptions shorter than the skipped prefix, which
    # previously collapsed to their first character.
    return description[:-1]
|
123
|
+
|
124
|
+
|
125
|
+
def extract_name(
    name_raw: Optional[str],
    description: str,
) -> Tuple[Optional[str], str, str]:
    r"""Extracts the molecule name from a description and anonymizes it.

    The first sentence of :obj:`description` (see
    :meth:`clean_up_description`) is split at its first linking verb
    (``" is "``, ``" are "``, ``" was "``, ...). The text before the verb is
    taken as the molecule name. If no verb is found, :obj:`name_raw` is used
    when the sentence starts with it.

    Args:
        name_raw (str, optional): The record's raw name, or :obj:`None` if the
            record has none.
        description (str): The raw description string.

    Returns:
        A tuple of three values:

        - The extracted name, or :obj:`None` if none could be determined.
        - :obj:`description` with every occurrence of that name replaced by
          ``"This molecule"``. ``"These molecules"`` is used instead if the
          first sentence contains ``" are "`` or ``" were "``. If no name was
          extracted, the description is returned unchanged.
        - The first sentence with its linking verbs replaced by a splitter
          token.
    """
    first_sentence = clean_up_description(description)

    splitter = ' -- -- '
    if ' are ' in first_sentence or ' were ' in first_sentence:
        replaced_words = 'These molecules'
    else:
        replaced_words = 'This molecule'

    # Mark every linking verb with the splitter; order matches the original
    # curation and is preserved deliberately.
    for verb in (
            ' is ',
            ' are ',
            ' was ',
            ' were ',
            ' appears ',
            ' occurs ',
            ' stands for ',
            ' belongs to ',
            ' exists ',  # only for CID=11443
            ' has been used in trials ',
            ' has been investigated ',
            ' has many uses ',
    ):
        first_sentence = first_sentence.replace(verb, splitter)

    if splitter in first_sentence:
        extracted_name = first_sentence.split(splitter, 1)[0]
    elif name_raw and first_sentence.startswith(name_raw):
        # `name_raw` may be `None` for records without a "Name" field.
        extracted_name = name_raw
    else:
        # Includes the case where the raw name only appears mid-sentence,
        # which is intentionally not treated as a match.
        extracted_name = None

    if not extracted_name:
        # An empty name would make `str.replace` insert `replaced_words`
        # between every character of the description.
        return None, description, first_sentence

    extracted_description = description.replace(extracted_name,
                                                 replaced_words)
    return extracted_name, extracted_description, first_sentence
|
172
|
+
|
173
|
+
|
174
|
+
class MoleculeGPTDataset(InMemoryDataset):
    r"""The dataset from the `"MoleculeGPT: Instruction Following Large
    Language Models for Molecular Property Prediction"
    <https://ai4d3.github.io/papers/34.pdf>`_ paper.

    Molecule descriptions are fetched from the PubChem annotation API and
    matched against PubChem SDF dumps. Each processed
    :class:`~torch_geometric.data.Data` object holds the following fields:

    - :obj:`x`: one-hot atom types over H, C, N, O, F and an "unknown"
      bucket.
    - :obj:`edge_index` and :obj:`edge_attr`: both bond directions, with
      one-hot bond types (single, double, triple, aromatic).
    - :obj:`smiles`: the RDKit canonical SMILES.
    - :obj:`instruction`: an LLM-generated question.
    - :obj:`y`: the description text answering that question.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
        total_page_num (int, optional): The number of PubChem description
            annotation pages to fetch. (default: :obj:`10`)
        total_block_num (int, optional): The number of PubChem SDF block
            files (of :obj:`block_size` compound IDs each) to download.
            (default: :obj:`1`)
    """
    description_url = (
        'https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/annotations/'
        'heading/json?heading_type=Compound&heading=Record+Description&page={}'
    )
    compound_url = ('https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/'
                    'CURRENT-Full/SDF')
    # Number of compound IDs covered by one PubChem SDF block file.
    block_size = 500000

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
        total_page_num: int = 10,
        total_block_num: int = 1,
    ):
        self.total_page_num = total_page_num
        self.total_block_num = total_block_num

        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return ['pubchem.csv']

    @property
    def processed_file_names(self) -> List[str]:
        return ['data.pt']

    @classmethod
    def _compound_file_name(cls, block_id: int) -> str:
        r"""Returns the PubChem SDF file name of block :obj:`block_id`."""
        l_id = block_id * cls.block_size + 1
        r_id = (block_id + 1) * cls.block_size
        return f"Compound_{l_id:09d}_{r_id:09d}.sdf.gz"

    @staticmethod
    def _filter_sdf_block(
        in_path: str,
        out_path: str,
        target_CIDs: set,
        block_id: int,
    ) -> int:
        r"""Copies the molecules of the gzipped SDF file :obj:`in_path` whose
        ``PUBCHEM_COMPOUND_CID`` is in :obj:`target_CIDs` to the SDF file
        :obj:`out_path`, and returns how many were written.

        This is a static method rather than a closure inside :meth:`process`
        so that it can be pickled and dispatched to
        :class:`multiprocessing.Pool` workers.
        """
        from rdkit import Chem

        valid_mol_count = 0
        writer = Chem.SDWriter(out_path)
        try:
            with gzip.open(in_path) as gzip_loader:
                for mol in tqdm(Chem.ForwardSDMolSupplier(gzip_loader)):
                    if mol is None:  # Unparsable entry.
                        continue
                    if mol.GetProp("PUBCHEM_COMPOUND_CID") not in target_CIDs:
                        continue
                    writer.write(mol)
                    valid_mol_count += 1
        finally:
            writer.close()

        print(f"block id: {block_id}\nfound {valid_mol_count}\n\n")
        sys.stdout.flush()
        return valid_mol_count

    def download(self) -> None:
        # Step 01. Fetch and extract descriptions:
        step1_folder = f"{self.raw_dir}/step_01_PubChemSTM_description"
        os.makedirs(step1_folder, exist_ok=True)

        valid_CID_set = set()
        CID2name_raw = defaultdict(list)
        CID2name_extracted = defaultdict(list)
        CID2text_raw = defaultdict(list)
        CID2text_extracted = defaultdict(list)

        for page_num in tqdm(range(1, self.total_page_num + 1)):
            response = requests.get(self.description_url.format(page_num))
            response.raise_for_status()
            annotations = response.json()["Annotations"]
            if annotations["Page"] != page_num:
                raise ValueError(f"Expected page {page_num} from PubChem, "
                                 f"got page {annotations['Page']}")

            out_path = f"{step1_folder}/Compound_description_{page_num}.txt"
            with open(out_path, "w", encoding="utf-8") as f_out:
                for record in annotations["Annotation"]:
                    try:
                        CID = record["LinkedRecords"]["CID"][0]
                        if "Name" in record:
                            name_raw = record["Name"]
                            CID2name_raw[CID].append(name_raw)
                        else:
                            name_raw = None

                        for data in record["Data"]:
                            markup = data["Value"]["StringWithMarkup"]
                            description = markup[0]["String"].strip()

                            extracted_name, extracted_description, _ = \
                                extract_name(name_raw, description)
                            if extracted_name is not None:
                                CID2name_extracted[CID].append(extracted_name)

                            CID2text_raw[CID].append(description)
                            CID2text_extracted[CID].append(
                                extracted_description)

                            valid_CID_set.add(CID)
                            f_out.write(f"{CID}\n")
                            f_out.write(f"{extracted_description}\n\n")
                    except Exception:
                        # Best effort: records missing any of the expected
                        # fields are skipped.
                        continue

        print(f"Total CID (with raw name) {len(CID2name_raw)}")
        print(f"Total CID (with extracted name) {len(CID2name_extracted)}")
        print(f"Total CID {len(valid_CID_set)}")

        for file_name, mapping in (
            ("CID2name_raw.json", CID2name_raw),
            ("CID2name.json", CID2name_extracted),
            ("CID2text_raw.json", CID2text_raw),
            ("CID2text.json", CID2text_extracted),
        ):
            with open(f"{self.raw_dir}/{file_name}", "w") as f:
                json.dump(mapping, f)

        # Step 02. Download SDF files. `download_url` returns early for files
        # that already exist, so an interrupted download is resumed instead
        # of being skipped once the folder exists:
        step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF"
        for block_id in tqdm(range(self.total_block_num)):
            compound_file_name = self._compound_file_name(block_id)
            download_url(f"{self.compound_url}/{compound_file_name}",
                         step2_folder)

    def process(self, use_mp: bool = False) -> None:
        try:
            from rdkit import Chem
            from rdkit.Chem.rdchem import BondType as BT
        except ImportError:
            print(("Using a pre-processed version of the dataset. Please "
                   "install 'rdkit' to alternatively process the raw data."),
                  file=sys.stderr)

            # NOTE(review): `download()` never writes `raw_paths[0]`
            # ('pubchem.csv'); this branch assumes a pre-processed file was
            # placed there by other means -- confirm.
            data_list = fs.torch_load(self.raw_paths[0])
            data_list = [Data(**data_dict) for data_dict in data_list]

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            self.save(data_list, self.processed_paths[0])
            return

        # Step 03. Filter SDF blocks down to molecules with a description:
        step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF"
        step3_folder = f"{self.raw_dir}/step_03_PubChemSTM_filtered"
        os.makedirs(step3_folder, exist_ok=True)

        with open(f"{self.raw_dir}/CID2text.json") as f:
            CID2text = json.load(f)
        target_CIDs = set(CID2text.keys())

        tasks = [(
            f"{step2_folder}/{self._compound_file_name(block_id)}",
            f"{step3_folder}/filtered_{block_id}.sdf",
            target_CIDs,
            block_id,
        ) for block_id in range(self.total_block_num)]

        if use_mp:
            num_process = multiprocessing.cpu_count()
            print(f"{num_process} CPUs")
            # Use at most 8 worker processes:
            with Pool(min(num_process, 8)) as pool:
                pool.starmap(self._filter_sdf_block, tasks)
        else:
            for task in tasks:
                self._filter_sdf_block(*task)

        # Step 04. Merge the filtered blocks into a single SDF file:
        print(f'The length of target_CID_list: {len(target_CIDs)}')

        merged_path = f'{self.raw_dir}/molecules.sdf'
        found_CID_set = set()
        writer = Chem.SDWriter(merged_path)
        try:
            # Only blocks written by step 03 of this run; iterating one block
            # further could merge a stale file from an earlier, larger run.
            for block_id in range(self.total_block_num):
                compound_file_path = f"{step3_folder}/filtered_{block_id}.sdf"
                try:
                    suppl = Chem.SDMolSupplier(compound_file_path)
                except Exception:
                    print(f"block id: {block_id} with 0 valid SDF file")
                    continue

                for mol in tqdm(suppl):
                    if mol is None:  # Skip unparsable entries.
                        continue
                    writer.write(mol)
                    found_CID_set.add(mol.GetProp("PUBCHEM_COMPOUND_CID"))
        finally:
            writer.close()
        print(f"In total: {len(found_CID_set)} molecules")

        # Step 05. Convert to PyG data format:
        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5}
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

        llm = LLM(
            model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
            num_params=1,
            dtype=torch.bfloat16,
        )
        prompt = ("Propose a question regarding the molecule '∼' "
                  "whose answer is: {}:")

        data_list = []
        for mol in tqdm(Chem.SDMolSupplier(merged_path)):
            if mol is None or not mol.HasProp('PUBCHEM_COMPOUND_CID'):
                continue
            CID = mol.GetProp("PUBCHEM_COMPOUND_CID")
            CAN_SMILES = mol.GetProp("PUBCHEM_OPENEYE_CAN_SMILES")

            m = Chem.MolFromSmiles(CAN_SMILES)
            if m is None:
                continue
            mol_bonds = list(m.GetBonds())
            # Skip molecules with bond types outside `bonds` instead of
            # aborting the whole run with a `KeyError`; checked before the
            # (expensive) LLM call.
            if any(bond.GetBondType() not in bonds for bond in mol_bonds):
                continue
            RDKit_CAN_SMILES = Chem.MolToSmiles(m)

            ground_truth = CID2text[CID][0]
            instruction = llm.inference([prompt.format(ground_truth)])[0]

            x = torch.tensor(
                [types.get(atom.GetSymbol(), 5) for atom in m.GetAtoms()],
                dtype=torch.long,
            )
            x = one_hot(x, num_classes=len(types), dtype=torch.float)

            rows, cols, edge_types = [], [], []
            for bond in mol_bonds:
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                # Store both directions of every bond:
                edge_types += [bonds[bond.GetBondType()]] * 2
                rows += [i, j]
                cols += [j, i]

            edge_index = torch.tensor([rows, cols], dtype=torch.long)
            edge_type = torch.tensor(edge_types, dtype=torch.long)
            edge_attr = one_hot(edge_type, num_classes=len(bonds))

            data = Data(
                x=x,
                edge_index=edge_index,
                edge_attr=edge_attr,
                smiles=RDKit_CAN_SMILES,
                instruction=instruction,
                y=ground_truth,
            )

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        self.save(data_list, self.processed_paths[0])
|
@@ -1,12 +1,13 @@
|
|
1
1
|
import os
|
2
2
|
import os.path as osp
|
3
3
|
import re
|
4
|
+
import warnings
|
4
5
|
from typing import Callable, Dict, Optional, Tuple, Union
|
5
6
|
|
6
7
|
import torch
|
7
8
|
|
8
9
|
from torch_geometric.data import InMemoryDataset, download_url, extract_gz
|
9
|
-
from torch_geometric.utils import from_smiles
|
10
|
+
from torch_geometric.utils import from_smiles as _from_smiles
|
10
11
|
|
11
12
|
|
12
13
|
class MoleculeNet(InMemoryDataset):
|
@@ -38,6 +39,10 @@ class MoleculeNet(InMemoryDataset):
|
|
38
39
|
final dataset. (default: :obj:`None`)
|
39
40
|
force_reload (bool, optional): Whether to re-process the dataset.
|
40
41
|
(default: :obj:`False`)
|
42
|
+
from_smiles (callable, optional): A custom function that takes a SMILES
|
43
|
+
string and outputs a :obj:`~torch_geometric.data.Data` object.
|
44
|
+
If not set, defaults to :meth:`~torch_geometric.utils.from_smiles`.
|
45
|
+
(default: :obj:`None`)
|
41
46
|
|
42
47
|
**STATS:**
|
43
48
|
|
@@ -152,9 +157,11 @@ class MoleculeNet(InMemoryDataset):
|
|
152
157
|
pre_transform: Optional[Callable] = None,
|
153
158
|
pre_filter: Optional[Callable] = None,
|
154
159
|
force_reload: bool = False,
|
160
|
+
from_smiles: Optional[Callable] = None,
|
155
161
|
) -> None:
|
156
162
|
self.name = name.lower()
|
157
163
|
assert self.name in self.names.keys()
|
164
|
+
self.from_smiles = from_smiles or _from_smiles
|
158
165
|
super().__init__(root, transform, pre_transform, pre_filter,
|
159
166
|
force_reload=force_reload)
|
160
167
|
self.load(self.processed_paths[0])
|
@@ -183,7 +190,7 @@ class MoleculeNet(InMemoryDataset):
|
|
183
190
|
os.unlink(path)
|
184
191
|
|
185
192
|
def process(self) -> None:
|
186
|
-
with open(self.raw_paths[0]
|
193
|
+
with open(self.raw_paths[0]) as f:
|
187
194
|
dataset = f.read().split('\n')[1:-1]
|
188
195
|
dataset = [x for x in dataset if len(x) > 0] # Filter empty lines.
|
189
196
|
|
@@ -199,9 +206,14 @@ class MoleculeNet(InMemoryDataset):
|
|
199
206
|
ys = [float(y) if len(y) > 0 else float('NaN') for y in labels]
|
200
207
|
y = torch.tensor(ys, dtype=torch.float).view(1, -1)
|
201
208
|
|
202
|
-
data = from_smiles(smiles)
|
209
|
+
data = self.from_smiles(smiles)
|
203
210
|
data.y = y
|
204
211
|
|
212
|
+
if data.num_nodes == 0:
|
213
|
+
warnings.warn(f"Skipping molecule '{smiles}' since it "
|
214
|
+
f"resulted in zero atoms")
|
215
|
+
continue
|
216
|
+
|
205
217
|
if self.pre_filter is not None and not self.pre_filter(data):
|
206
218
|
continue
|
207
219
|
|
@@ -2,8 +2,6 @@ import os
|
|
2
2
|
import os.path as osp
|
3
3
|
from typing import Callable, List, Optional
|
4
4
|
|
5
|
-
import torch
|
6
|
-
|
7
5
|
from torch_geometric.data import (
|
8
6
|
Data,
|
9
7
|
InMemoryDataset,
|
@@ -110,7 +108,7 @@ class NeuroGraphDataset(InMemoryDataset):
|
|
110
108
|
fs.rm(osp.join(self.raw_dir, self.name))
|
111
109
|
|
112
110
|
def process(self) -> None:
|
113
|
-
data, slices =
|
111
|
+
data, slices = fs.torch_load(self.raw_paths[0])
|
114
112
|
|
115
113
|
num_samples = slices['x'].size(0) - 1
|
116
114
|
data_list: List[Data] = []
|
@@ -147,7 +147,7 @@ class OGB_MAG(InMemoryDataset):
|
|
147
147
|
for node_type in ['author', 'institution', 'field_of_study']:
|
148
148
|
data[node_type].num_nodes = num_nodes_df[node_type].tolist()[0]
|
149
149
|
else:
|
150
|
-
emb_dict =
|
150
|
+
emb_dict = fs.torch_load(self.raw_paths[-1])
|
151
151
|
for key, value in emb_dict.items():
|
152
152
|
if key != 'paper':
|
153
153
|
data[key].x = value
|