pyg-nightly 2.6.0.dev20240511__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (205)
  1. {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +30 -31
  2. {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +205 -181
  3. {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +26 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +16 -14
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/data.py +13 -8
  12. torch_geometric/data/database.py +15 -7
  13. torch_geometric/data/dataset.py +14 -6
  14. torch_geometric/data/feature_store.py +13 -22
  15. torch_geometric/data/graph_store.py +0 -4
  16. torch_geometric/data/hetero_data.py +4 -4
  17. torch_geometric/data/in_memory_dataset.py +2 -4
  18. torch_geometric/data/large_graph_indexer.py +677 -0
  19. torch_geometric/data/lightning/datamodule.py +4 -4
  20. torch_geometric/data/storage.py +15 -5
  21. torch_geometric/data/summary.py +14 -4
  22. torch_geometric/data/temporal.py +1 -2
  23. torch_geometric/datasets/__init__.py +11 -1
  24. torch_geometric/datasets/actor.py +9 -11
  25. torch_geometric/datasets/airfrans.py +15 -18
  26. torch_geometric/datasets/airports.py +10 -12
  27. torch_geometric/datasets/amazon.py +8 -11
  28. torch_geometric/datasets/amazon_book.py +9 -10
  29. torch_geometric/datasets/amazon_products.py +9 -10
  30. torch_geometric/datasets/aminer.py +8 -9
  31. torch_geometric/datasets/aqsol.py +10 -13
  32. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  33. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  34. torch_geometric/datasets/ba_shapes.py +5 -6
  35. torch_geometric/datasets/bitcoin_otc.py +1 -1
  36. torch_geometric/datasets/brca_tgca.py +1 -1
  37. torch_geometric/datasets/dblp.py +2 -1
  38. torch_geometric/datasets/dbp15k.py +2 -2
  39. torch_geometric/datasets/fake.py +1 -3
  40. torch_geometric/datasets/flickr.py +2 -1
  41. torch_geometric/datasets/freebase.py +1 -1
  42. torch_geometric/datasets/gdelt_lite.py +3 -2
  43. torch_geometric/datasets/ged_dataset.py +3 -2
  44. torch_geometric/datasets/git_mol_dataset.py +263 -0
  45. torch_geometric/datasets/gnn_benchmark_dataset.py +6 -5
  46. torch_geometric/datasets/hgb_dataset.py +8 -8
  47. torch_geometric/datasets/imdb.py +2 -1
  48. torch_geometric/datasets/last_fm.py +2 -1
  49. torch_geometric/datasets/linkx_dataset.py +4 -3
  50. torch_geometric/datasets/lrgb.py +3 -5
  51. torch_geometric/datasets/malnet_tiny.py +4 -3
  52. torch_geometric/datasets/mnist_superpixels.py +2 -3
  53. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  54. torch_geometric/datasets/molecule_net.py +7 -1
  55. torch_geometric/datasets/motif_generator/base.py +0 -1
  56. torch_geometric/datasets/neurograph.py +1 -3
  57. torch_geometric/datasets/ogb_mag.py +1 -1
  58. torch_geometric/datasets/opf.py +239 -0
  59. torch_geometric/datasets/ose_gvcs.py +1 -1
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  62. torch_geometric/datasets/pcqm4m.py +2 -1
  63. torch_geometric/datasets/ppi.py +1 -1
  64. torch_geometric/datasets/qm9.py +4 -3
  65. torch_geometric/datasets/reddit.py +2 -1
  66. torch_geometric/datasets/reddit2.py +2 -1
  67. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  68. torch_geometric/datasets/s3dis.py +2 -2
  69. torch_geometric/datasets/shapenet.py +3 -3
  70. torch_geometric/datasets/shrec2016.py +2 -2
  71. torch_geometric/datasets/tag_dataset.py +350 -0
  72. torch_geometric/datasets/upfd.py +2 -1
  73. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  74. torch_geometric/datasets/webkb.py +2 -2
  75. torch_geometric/datasets/wikics.py +1 -1
  76. torch_geometric/datasets/wikidata.py +3 -2
  77. torch_geometric/datasets/wikipedia_network.py +2 -2
  78. torch_geometric/datasets/word_net.py +2 -2
  79. torch_geometric/datasets/yelp.py +2 -1
  80. torch_geometric/datasets/zinc.py +1 -1
  81. torch_geometric/device.py +42 -0
  82. torch_geometric/distributed/local_feature_store.py +3 -2
  83. torch_geometric/distributed/local_graph_store.py +2 -1
  84. torch_geometric/distributed/partition.py +9 -8
  85. torch_geometric/edge_index.py +17 -8
  86. torch_geometric/explain/algorithm/base.py +0 -1
  87. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  88. torch_geometric/explain/explanation.py +2 -2
  89. torch_geometric/graphgym/checkpoint.py +2 -1
  90. torch_geometric/graphgym/logger.py +4 -4
  91. torch_geometric/graphgym/loss.py +1 -1
  92. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  93. torch_geometric/index.py +20 -7
  94. torch_geometric/inspector.py +6 -2
  95. torch_geometric/io/fs.py +28 -2
  96. torch_geometric/io/npz.py +2 -1
  97. torch_geometric/io/off.py +2 -2
  98. torch_geometric/io/sdf.py +2 -2
  99. torch_geometric/io/tu.py +2 -3
  100. torch_geometric/loader/__init__.py +4 -0
  101. torch_geometric/loader/cluster.py +9 -3
  102. torch_geometric/loader/graph_saint.py +2 -1
  103. torch_geometric/loader/ibmb_loader.py +12 -4
  104. torch_geometric/loader/mixin.py +1 -1
  105. torch_geometric/loader/neighbor_loader.py +1 -1
  106. torch_geometric/loader/neighbor_sampler.py +2 -2
  107. torch_geometric/loader/prefetch.py +1 -1
  108. torch_geometric/loader/rag_loader.py +107 -0
  109. torch_geometric/loader/zip_loader.py +10 -0
  110. torch_geometric/metrics/__init__.py +11 -2
  111. torch_geometric/metrics/link_pred.py +159 -34
  112. torch_geometric/nn/aggr/__init__.py +2 -0
  113. torch_geometric/nn/aggr/attention.py +0 -2
  114. torch_geometric/nn/aggr/base.py +2 -4
  115. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  116. torch_geometric/nn/aggr/set_transformer.py +1 -1
  117. torch_geometric/nn/attention/__init__.py +5 -1
  118. torch_geometric/nn/attention/qformer.py +71 -0
  119. torch_geometric/nn/conv/collect.jinja +6 -3
  120. torch_geometric/nn/conv/cugraph/base.py +0 -1
  121. torch_geometric/nn/conv/edge_conv.py +3 -2
  122. torch_geometric/nn/conv/gat_conv.py +35 -7
  123. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  124. torch_geometric/nn/conv/general_conv.py +1 -1
  125. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  126. torch_geometric/nn/conv/hetero_conv.py +3 -3
  127. torch_geometric/nn/conv/hgt_conv.py +1 -1
  128. torch_geometric/nn/conv/message_passing.py +100 -82
  129. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  130. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  131. torch_geometric/nn/conv/spline_conv.py +4 -4
  132. torch_geometric/nn/conv/x_conv.py +3 -2
  133. torch_geometric/nn/dense/linear.py +5 -4
  134. torch_geometric/nn/fx.py +3 -3
  135. torch_geometric/nn/model_hub.py +3 -1
  136. torch_geometric/nn/models/__init__.py +10 -2
  137. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  138. torch_geometric/nn/models/dimenet_utils.py +5 -7
  139. torch_geometric/nn/models/g_retriever.py +230 -0
  140. torch_geometric/nn/models/git_mol.py +336 -0
  141. torch_geometric/nn/models/glem.py +385 -0
  142. torch_geometric/nn/models/gnnff.py +0 -1
  143. torch_geometric/nn/models/graph_unet.py +12 -3
  144. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  145. torch_geometric/nn/models/lightgcn.py +1 -1
  146. torch_geometric/nn/models/metapath2vec.py +3 -4
  147. torch_geometric/nn/models/molecule_gpt.py +222 -0
  148. torch_geometric/nn/models/node2vec.py +1 -2
  149. torch_geometric/nn/models/schnet.py +2 -1
  150. torch_geometric/nn/models/signed_gcn.py +3 -3
  151. torch_geometric/nn/module_dict.py +2 -2
  152. torch_geometric/nn/nlp/__init__.py +9 -0
  153. torch_geometric/nn/nlp/llm.py +322 -0
  154. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  155. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  156. torch_geometric/nn/norm/batch_norm.py +1 -1
  157. torch_geometric/nn/parameter_dict.py +2 -2
  158. torch_geometric/nn/pool/__init__.py +7 -5
  159. torch_geometric/nn/pool/cluster_pool.py +145 -0
  160. torch_geometric/nn/pool/connect/base.py +0 -1
  161. torch_geometric/nn/pool/edge_pool.py +1 -1
  162. torch_geometric/nn/pool/graclus.py +4 -2
  163. torch_geometric/nn/pool/select/base.py +0 -1
  164. torch_geometric/nn/pool/voxel_grid.py +3 -2
  165. torch_geometric/nn/resolver.py +1 -1
  166. torch_geometric/nn/sequential.jinja +10 -23
  167. torch_geometric/nn/sequential.py +203 -77
  168. torch_geometric/nn/summary.py +1 -1
  169. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  170. torch_geometric/profile/__init__.py +2 -0
  171. torch_geometric/profile/nvtx.py +66 -0
  172. torch_geometric/profile/profiler.py +24 -15
  173. torch_geometric/resolver.py +1 -1
  174. torch_geometric/sampler/base.py +34 -13
  175. torch_geometric/sampler/neighbor_sampler.py +11 -10
  176. torch_geometric/testing/decorators.py +17 -22
  177. torch_geometric/transforms/__init__.py +2 -0
  178. torch_geometric/transforms/add_metapaths.py +4 -4
  179. torch_geometric/transforms/add_positional_encoding.py +1 -1
  180. torch_geometric/transforms/delaunay.py +65 -14
  181. torch_geometric/transforms/face_to_edge.py +32 -3
  182. torch_geometric/transforms/gdc.py +7 -6
  183. torch_geometric/transforms/laplacian_lambda_max.py +2 -2
  184. torch_geometric/transforms/mask.py +5 -1
  185. torch_geometric/transforms/node_property_split.py +1 -2
  186. torch_geometric/transforms/pad.py +7 -6
  187. torch_geometric/transforms/random_link_split.py +1 -1
  188. torch_geometric/transforms/remove_self_loops.py +36 -0
  189. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  190. torch_geometric/transforms/virtual_node.py +2 -1
  191. torch_geometric/typing.py +31 -5
  192. torch_geometric/utils/__init__.py +5 -1
  193. torch_geometric/utils/_negative_sampling.py +1 -1
  194. torch_geometric/utils/_normalize_edge_index.py +46 -0
  195. torch_geometric/utils/_scatter.py +37 -12
  196. torch_geometric/utils/_subgraph.py +4 -0
  197. torch_geometric/utils/_tree_decomposition.py +2 -2
  198. torch_geometric/utils/augmentation.py +1 -1
  199. torch_geometric/utils/convert.py +5 -5
  200. torch_geometric/utils/geodesic.py +24 -22
  201. torch_geometric/utils/hetero.py +1 -1
  202. torch_geometric/utils/map.py +1 -1
  203. torch_geometric/utils/smiles.py +66 -28
  204. torch_geometric/utils/sparse.py +25 -10
  205. torch_geometric/visualization/graph.py +3 -4
@@ -0,0 +1,485 @@
1
+ import gzip
2
+ import json
3
+ import multiprocessing
4
+ import os
5
+ import sys
6
+ from collections import defaultdict
7
+ from multiprocessing import Pool
8
+ from typing import Callable, List, Optional, Tuple
9
+
10
+ import numpy as np
11
+ import requests
12
+ import torch
13
+ from tqdm import tqdm
14
+
15
+ from torch_geometric.data import Data, InMemoryDataset, download_url
16
+ from torch_geometric.io import fs
17
+ from torch_geometric.nn.nlp import LLM
18
+ from torch_geometric.utils import one_hot
19
+
20
+
21
def clean_up_description(description: str) -> str:
    r"""Returns the normalized first sentence of a PubChem description.

    A sentinel space is appended to :obj:`description`, and a fixed list of
    hand-curated substitutions is applied so that the leading sentence
    follows a ``"<name> is ..."`` pattern. The text is then cut at the first
    ``". "`` that is followed by an ASCII upper-case letter.

    Args:
        description (str): The raw description text.

    Returns:
        str: The first sentence, without the sentinel space. If no sentence
        break is found, the whole (normalized) description is returned.
    """
    description = description + " "

    # Drop a leading "Pure " adjective (the prefix only, not every match).
    if description.startswith("Pure "):
        description = description[len("Pure "):]
    # Fix a known typo at the start of one description.
    if description.startswith("Mercurycombines"):
        description = ("Mercury combines" +
                       description[len("Mercurycombines"):])

    # Special cases: descriptions whose subject is not followed by a linking
    # verb that `extract_name` can split on. Applied in this exact order,
    # each replacement seeing the output of the previous ones.
    special_cases = (
        ("17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione. ",
         "17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione is "),
        ("5-Thymidylic acid. ", "5-Thymidylic acid. is "),
        ("5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. ",
         "5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. is "),
        (("Guanosine 5'-(trihydrogen diphosphate), monoanhydride"
          " with phosphorothioic acid. "),
         ("Guanosine 5'-(trihydrogen diphosphate), monoanhydride"
          " with phosphorothioic acid is ")),
        ("5'-Uridylic acid. ", "5'-Uridylic acid is "),
        ("5'-Adenylic acid, ", "5'-Adenylic acid is "),
        ("Uridine 5'-(tetrahydrogen triphosphate). ",
         "Uridine 5'-(tetrahydrogen triphosphate). is "),
        ("Inosine 5'-Monophosphate. ", "Inosine 5'-Monophosphate. is "),
        ("Pivaloyloxymethyl butyrate (AN-9), ",
         "Pivaloyloxymethyl butyrate (AN-9) is "),
        ("4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine. ",
         "4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine is "),
        ("Cardamonin (also known as Dihydroxymethoxychalcone), ",
         "Cardamonin (also known as Dihydroxymethoxychalcone) is "),
        ("Lithium has been used to treat ", "Lithium is "),
        ("4,4'-Methylenebis ", "4,4'-Methylenebis is "),
        # NOTE(review): this pattern has no trailing space, so every
        # occurrence gains " is " -- kept as-is; confirm it is intended.
        ("2,3,7,8-Tetrachlorodibenzo-p-dioxin",
         "2,3,7,8-Tetrachlorodibenzo-p-dioxin is "),
        ("Exposure to 2,4,5-trichlorophenol ",
         "2,4,5-Trichlorophenol exposure "),
    )
    for old, new in special_cases:
        description = description.replace(old, new)

    # Prefixes that themselves contain ". " and must not end the sentence.
    start_index = 0
    for prefix in ('C.I. ', 'Nectriapyrone. D ',
                   'Salmonella enterica sv. Minnesota LPS core oligosaccharide'):
        if description.startswith(prefix):
            start_index = len(prefix)
            break

    # Default: everything except the sentinel space (also covers the case
    # where the scan below does not run at all).
    end = len(description) - 1
    for index in range(start_index, len(description) - 2):
        if (description[index] == '.' and description[index + 1] == ' '
                and 'A' <= description[index + 2] <= 'Z'):
            end = index + 1
            break

    return description[:end]
123
+
124
+
125
def extract_name(
    name_raw: Optional[str],
    description: str,
) -> Tuple[Optional[str], str, str]:
    r"""Extracts the molecule name that a PubChem description starts with.

    The first sentence of :obj:`description` (as returned by
    ``clean_up_description``) is split at the first supported linking phrase
    (e.g. ``" is "``, ``" belongs to "``), and the part before it is taken
    as the name. If no linking phrase is present, :obj:`name_raw` is used
    when the sentence starts with it. Every occurrence of the extracted name
    in :obj:`description` is then replaced by ``"This molecule"``, or by
    ``"These molecules"`` when the sentence contains ``" are "`` or
    ``" were "``.

    Args:
        name_raw (str, optional): The compound name reported by PubChem, or
            :obj:`None` if the record has none.
        description (str): The raw description text.

    Returns:
        (Optional[str], str, str): The extracted name (or :obj:`None`), the
        description with that name replaced, and the processed first
        sentence (with linking phrases replaced by an internal splitter).
    """
    first_sentence = clean_up_description(description)

    splitter = ' -- -- '
    if ' are ' in first_sentence or ' were ' in first_sentence:
        replaced_words = 'These molecules'
    else:
        replaced_words = 'This molecule'

    # Normalize every supported linking phrase to one splitter token, in
    # this order. (' exists ' is only needed for CID=11443.)
    linking_phrases = (
        ' is ', ' are ', ' was ', ' were ', ' appears ', ' occurs ',
        ' stands for ', ' belongs to ', ' exists ',
        ' has been used in trials ', ' has been investigated ',
        ' has many uses ')
    for phrase in linking_phrases:
        first_sentence = first_sentence.replace(phrase, splitter)

    if splitter in first_sentence:
        extracted_name = first_sentence.split(splitter, 1)[0]
    elif name_raw is not None and first_sentence.startswith(name_raw):
        extracted_name = name_raw
    else:
        # No usable name: missing `name_raw`, or `name_raw` only appears
        # mid-sentence (such matches are deliberately not used).
        extracted_name = None

    # An empty name would make `str.replace` insert `replaced_words`
    # between every character of the description.
    if extracted_name:
        extracted_description = description.replace(extracted_name,
                                                    replaced_words)
    else:
        extracted_name = None
        extracted_description = description

    return extracted_name, extracted_description, first_sentence
172
+
173
+
174
class MoleculeGPTDataset(InMemoryDataset):
    r"""The dataset from the `"MoleculeGPT: Instruction Following Large
    Language Models for Molecular Property Prediction"
    <https://ai4d3.github.io/papers/34.pdf>`_ paper.

    The raw data is assembled from PubChem in five steps.

    * :meth:`download` fetches compound descriptions from the PUG-View
      annotation API (step 1) and downloads gzipped SDF blocks from the
      PubChem FTP server (step 2).
    * :meth:`process` keeps only described compounds (step 3) and merges
      them into a single SDF file (step 4).
    * :meth:`process` then converts each molecule into a
      :class:`~torch_geometric.data.Data` object (step 5). The object's
      :obj:`instruction` is generated by an
      :class:`~torch_geometric.nn.nlp.LLM`.

    Processing the raw data requires :obj:`rdkit`. Without it, a
    pre-processed list of data dictionaries is loaded from
    :obj:`raw_paths[0]` instead.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
        total_page_num (int, optional): The number of pages from PubChem.
            (default: :obj:`10`)
        total_block_num (int, optional): The blocks of SDF files from PubChem.
            (default: :obj:`1`)
    """
    description_url = (
        'https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/annotations/'
        'heading/json?heading_type=Compound&heading=Record+Description&page={}'
    )
    compound_url = ('https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/'
                    'CURRENT-Full/SDF')

    # Number of compound IDs covered by one PubChem SDF block file.
    block_size = 500000

    # Seconds to wait for a PubChem API response before giving up.
    request_timeout = 60

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
        total_page_num: int = 10,
        total_block_num: int = 1,
    ):
        self.total_page_num = total_page_num
        self.total_block_num = total_block_num

        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return ['pubchem.csv']

    @property
    def processed_file_names(self) -> List[str]:
        return ['data.pt']

    @classmethod
    def _compound_file_name(cls, block_id: int) -> str:
        # Name of the gzipped PubChem SDF file holding block `block_id`.
        l_id = block_id * cls.block_size + 1
        r_id = (block_id + 1) * cls.block_size
        return f"Compound_{l_id:09d}_{r_id:09d}.sdf.gz"

    def download(self) -> None:
        # Step 01. Extract descriptions (skipped if already done).
        step1_folder = f"{self.raw_dir}/step_01_PubChemSTM_description"
        if not os.path.exists(step1_folder):
            os.makedirs(step1_folder)
            self._download_descriptions(step1_folder)

        # Step 02. Download SDF files (skipped if already done).
        step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF"
        if not os.path.exists(step2_folder):
            for block_id in tqdm(range(self.total_block_num)):
                compound_file_name = self._compound_file_name(block_id)
                download_url(f"{self.compound_url}/{compound_file_name}",
                             step2_folder)

    def _download_descriptions(self, step1_folder: str) -> None:
        # Fetches `total_page_num` annotation pages. It writes one text file
        # per page to `step1_folder`, and four CID -> name/text JSON maps to
        # `raw_dir`.
        valid_CID_set = set()
        CID2name_raw = defaultdict(list)
        CID2name_extracted = defaultdict(list)
        CID2text_raw = defaultdict(list)
        CID2text_extracted = defaultdict(list)

        for page_index in tqdm(range(self.total_page_num)):
            page_num = page_index + 1

            response = requests.get(self.description_url.format(page_num),
                                    timeout=self.request_timeout)
            response.raise_for_status()
            description_data = response.json()["Annotations"]
            if description_data["Page"] != page_num:
                raise RuntimeError(
                    f"Expected page {page_num} from PubChem, but received "
                    f"page {description_data['Page']}")

            page_path = f"{step1_folder}/Compound_description_{page_num}.txt"
            with open(page_path, "w", encoding="utf-8") as f_out:
                for record in description_data["Annotation"]:
                    try:
                        CID = record["LinkedRecords"]["CID"][0]
                        name_raw = record.get("Name")
                        if name_raw is not None:
                            CID2name_raw[CID].append(name_raw)

                        for data in record["Data"]:
                            description = data["Value"]["StringWithMarkup"][
                                0]["String"].strip()

                            extracted_name, extracted_description, _ = \
                                extract_name(name_raw, description)
                            if extracted_name is not None:
                                CID2name_extracted[CID].append(extracted_name)

                            CID2text_raw[CID].append(description)
                            CID2text_extracted[CID].append(
                                extracted_description)

                            valid_CID_set.add(CID)
                            f_out.write(f"{CID}\n")
                            f_out.write(f"{extracted_description}\n\n")
                    except (AttributeError, IndexError, KeyError, TypeError):
                        # Best effort: skip records whose JSON layout does
                        # not match the expected structure.
                        continue

        print(f"Total CID (with raw name) {len(CID2name_raw)}")
        print(f"Total CID (with extracted name) {len(CID2name_extracted)}")
        print(f"Total CID {len(valid_CID_set)}")

        with open(f"{self.raw_dir}/CID2name_raw.json", "w") as f:
            json.dump(CID2name_raw, f)

        with open(f"{self.raw_dir}/CID2name.json", "w") as f:
            json.dump(CID2name_extracted, f)

        with open(f"{self.raw_dir}/CID2text_raw.json", "w") as f:
            json.dump(CID2text_raw, f)

        with open(f"{self.raw_dir}/CID2text.json", "w") as f:
            json.dump(CID2text_extracted, f)

    @staticmethod
    def _filter_sdf_block(task: Tuple[int, str, str],
                          target_cids: set) -> None:
        # Copies the molecules of one gzipped SDF block whose PubChem CID is
        # in `target_cids` into a plain SDF file. It is defined at class level
        # rather than as a closure so that `multiprocessing.Pool` can pickle
        # it.
        from rdkit import Chem

        block_id, in_path, out_path = task
        valid_mol_count = 0

        writer = Chem.SDWriter(out_path)
        try:
            with gzip.open(in_path) as gzip_loader:
                suppl = Chem.ForwardSDMolSupplier(gzip_loader)
                for mol in tqdm(suppl):
                    if mol is None or not mol.HasProp("PUBCHEM_COMPOUND_CID"):
                        continue
                    if mol.GetProp("PUBCHEM_COMPOUND_CID") not in target_cids:
                        continue
                    writer.write(mol)
                    valid_mol_count += 1
        finally:
            writer.close()

        print(f"block id: {block_id}\nfound {valid_mol_count}\n\n")
        sys.stdout.flush()

    @staticmethod
    def _mol_to_graph(mol, types: dict, bonds: dict):
        # Returns (x, edge_index, edge_attr): one-hot atom types, edges in
        # both directions, and one-hot bond types. Returns None if `mol`
        # has a bond type that is missing from `bonds`.
        unknown = types['Unknow']
        # `dtype=torch.long` so that molecules without atoms still produce
        # a valid index tensor for `one_hot`.
        x = torch.tensor(
            [types.get(atom.GetSymbol(), unknown) for atom in mol.GetAtoms()],
            dtype=torch.long)
        x = one_hot(x, num_classes=len(types), dtype=torch.float)

        rows, cols, edge_types = [], [], []
        for bond in mol.GetBonds():
            bond_type = bonds.get(bond.GetBondType())
            if bond_type is None:
                return None
            i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            edge_types += [bond_type] * 2
            rows += [i, j]
            cols += [j, i]

        edge_index = torch.tensor([rows, cols], dtype=torch.long)
        edge_type = torch.tensor(edge_types, dtype=torch.long)
        edge_attr = one_hot(edge_type, num_classes=len(bonds))
        return x, edge_index, edge_attr

    def process(self, use_mp: bool = False) -> None:
        r"""Filters, merges and converts the downloaded PubChem data.

        Args:
            use_mp (bool, optional): If set to :obj:`True`, the SDF blocks are
                filtered in parallel by a :class:`multiprocessing.Pool` with
                at most 8 worker processes. (default: :obj:`False`)
        """
        from functools import partial

        try:
            from rdkit import Chem
            from rdkit.Chem.rdchem import BondType as BT
            WITH_RDKIT = True
        except ImportError:
            WITH_RDKIT = False

        if not WITH_RDKIT:
            print(("Using a pre-processed version of the dataset. Please "
                   "install 'rdkit' to alternatively process the raw data."),
                  file=sys.stderr)

            data_list = fs.torch_load(self.raw_paths[0])
            data_list = [Data(**data_dict) for data_dict in data_list]

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            self.save(data_list, self.processed_paths[0])
            return

        with open(f"{self.raw_dir}/CID2text.json") as f:
            CID2text = json.load(f)
        target_CID_set = set(CID2text.keys())

        # Step 03. Filter out SDF (skipped if already done).
        step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF"
        step3_folder = f"{self.raw_dir}/step_03_PubChemSTM_filtered"
        if not os.path.exists(step3_folder):
            os.makedirs(step3_folder)
            tasks = [(
                block_id,
                f"{step2_folder}/{self._compound_file_name(block_id)}",
                f"{step3_folder}/filtered_{block_id}.sdf",
            ) for block_id in range(self.total_block_num)]
            extract = partial(self._filter_sdf_block,
                              target_cids=target_CID_set)

            if use_mp:
                cpu_count = multiprocessing.cpu_count()
                print(f"{cpu_count} CPUs")
                with Pool(min(8, cpu_count)) as p:
                    p.map(extract, tasks)
            else:
                for task in tasks:
                    extract(task)

        # Step 04. Merge SDF
        print(f'The length of target_CID_list: {len(target_CID_set)}')
        merged_path = f'{self.raw_dir}/molecules.sdf'

        found_CID_set = set()
        writer = Chem.SDWriter(merged_path)
        try:
            for block_id in range(self.total_block_num):
                compound_file_path = f"{step3_folder}/filtered_{block_id}.sdf"
                if not os.path.exists(compound_file_path):
                    print(f"block id: {block_id} with 0 valid SDF file")
                    continue
                for mol in tqdm(Chem.SDMolSupplier(compound_file_path)):
                    if mol is None:
                        continue
                    writer.write(mol)
                    found_CID_set.add(mol.GetProp("PUBCHEM_COMPOUND_CID"))
        finally:
            writer.close()
        print(f"In total: {len(found_CID_set)} molecules")

        # Step 05. Convert to PyG data format
        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5}
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

        llm = LLM(
            model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
            num_params=1,
            dtype=torch.bfloat16,
        )
        prompt = ("Propose a question regarding the molecule '∼' "
                  "whose answer is: {}:")

        data_list = []
        for mol in tqdm(Chem.SDMolSupplier(merged_path)):
            if mol is None or not mol.HasProp('PUBCHEM_COMPOUND_CID'):
                continue
            CID = mol.GetProp("PUBCHEM_COMPOUND_CID")
            if (CID not in CID2text
                    or not mol.HasProp("PUBCHEM_OPENEYE_CAN_SMILES")):
                continue

            m = Chem.MolFromSmiles(mol.GetProp("PUBCHEM_OPENEYE_CAN_SMILES"))
            if m is None:
                continue

            graph = self._mol_to_graph(m, types, bonds)
            if graph is None:  # Unsupported bond type.
                continue
            x, edge_index, edge_attr = graph

            ground_truth = CID2text[CID][0]
            instruction = llm.inference([prompt.format(ground_truth)])[0]

            data = Data(
                x=x,
                edge_index=edge_index,
                edge_attr=edge_attr,
                smiles=Chem.MolToSmiles(m),
                instruction=instruction,
                y=ground_truth,
            )

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        self.save(data_list, self.processed_paths[0])
@@ -1,6 +1,7 @@
1
1
  import os
2
2
  import os.path as osp
3
3
  import re
4
+ import warnings
4
5
  from typing import Callable, Dict, Optional, Tuple, Union
5
6
 
6
7
  import torch
@@ -189,7 +190,7 @@ class MoleculeNet(InMemoryDataset):
189
190
  os.unlink(path)
190
191
 
191
192
  def process(self) -> None:
192
- with open(self.raw_paths[0], 'r') as f:
193
+ with open(self.raw_paths[0]) as f:
193
194
  dataset = f.read().split('\n')[1:-1]
194
195
  dataset = [x for x in dataset if len(x) > 0] # Filter empty lines.
195
196
 
@@ -208,6 +209,11 @@ class MoleculeNet(InMemoryDataset):
208
209
  data = self.from_smiles(smiles)
209
210
  data.y = y
210
211
 
212
+ if data.num_nodes == 0:
213
+ warnings.warn(f"Skipping molecule '{smiles}' since it "
214
+ f"resulted in zero atoms")
215
+ continue
216
+
211
217
  if self.pre_filter is not None and not self.pre_filter(data):
212
218
  continue
213
219
 
@@ -10,7 +10,6 @@ class MotifGenerator(ABC):
10
10
  @abstractmethod
11
11
  def __call__(self) -> Data:
12
12
  r"""To be implemented by :class:`Motif` subclasses."""
13
- pass
14
13
 
15
14
  @staticmethod
16
15
  def resolve(query: Any, *args: Any, **kwargs: Any) -> 'MotifGenerator':
@@ -2,8 +2,6 @@ import os
2
2
  import os.path as osp
3
3
  from typing import Callable, List, Optional
4
4
 
5
- import torch
6
-
7
5
  from torch_geometric.data import (
8
6
  Data,
9
7
  InMemoryDataset,
@@ -110,7 +108,7 @@ class NeuroGraphDataset(InMemoryDataset):
110
108
  fs.rm(osp.join(self.raw_dir, self.name))
111
109
 
112
110
  def process(self) -> None:
113
- data, slices = torch.load(self.raw_paths[0])
111
+ data, slices = fs.torch_load(self.raw_paths[0])
114
112
 
115
113
  num_samples = slices['x'].size(0) - 1
116
114
  data_list: List[Data] = []
@@ -147,7 +147,7 @@ class OGB_MAG(InMemoryDataset):
147
147
  for node_type in ['author', 'institution', 'field_of_study']:
148
148
  data[node_type].num_nodes = num_nodes_df[node_type].tolist()[0]
149
149
  else:
150
- emb_dict = torch.load(self.raw_paths[-1])
150
+ emb_dict = fs.torch_load(self.raw_paths[-1])
151
151
  for key, value in emb_dict.items():
152
152
  if key != 'paper':
153
153
  data[key].x = value