pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic. Click here for more details.

Files changed (268) hide show
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
@@ -1,14 +1,13 @@
1
1
  import os
2
2
  from typing import Callable, List, Optional
3
3
 
4
- import torch
5
-
6
4
  from torch_geometric.data import (
7
5
  Data,
8
6
  InMemoryDataset,
9
7
  download_url,
10
8
  extract_zip,
11
9
  )
10
+ from torch_geometric.io import fs
12
11
 
13
12
 
14
13
  class MNISTSuperpixels(InMemoryDataset):
@@ -85,7 +84,7 @@ class MNISTSuperpixels(InMemoryDataset):
85
84
  os.unlink(path)
86
85
 
87
86
  def process(self) -> None:
88
- inputs = torch.load(self.raw_paths[0])
87
+ inputs = fs.torch_load(self.raw_paths[0])
89
88
  for i in range(len(inputs)):
90
89
  data_list = [Data(**data_dict) for data_dict in inputs[i]]
91
90
 
@@ -79,7 +79,7 @@ class ModelNet(InMemoryDataset):
79
79
 
80
80
  urls = {
81
81
  '10':
82
- 'http://vision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip',
82
+ 'http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip', # noqa
83
83
  '40': 'http://modelnet.cs.princeton.edu/ModelNet40.zip'
84
84
  }
85
85
 
@@ -0,0 +1,492 @@
1
+ import gzip
2
+ import json
3
+ import multiprocessing
4
+ import os
5
+ import sys
6
+ from collections import defaultdict
7
+ from multiprocessing import Pool
8
+ from typing import Callable, List, Optional, Tuple
9
+
10
+ import numpy as np
11
+ import requests
12
+ import torch
13
+ from tqdm import tqdm
14
+
15
+ from torch_geometric.data import Data, InMemoryDataset, download_url
16
+ from torch_geometric.io import fs
17
+ from torch_geometric.llm.models import LLM
18
+ from torch_geometric.utils import one_hot
19
+
20
+
21
# Ordered ``(old, new)`` rewrites for descriptions whose opening sentence does
# not follow the usual "<name> is ..." pattern. They are applied sequentially,
# so the order below is significant and must be preserved.
# NOTE(review): a few targets keep the period before " is " (e.g.
# "5-Thymidylic acid. is "); this looks unintended but is kept byte-for-byte
# so that generated texts do not change -- confirm before normalizing.
_SPECIAL_CASE_REPLACEMENTS = (
    ("17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione. ",
     "17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione is "),
    ("5-Thymidylic acid. ", "5-Thymidylic acid. is "),
    ("5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. ",
     "5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. is "),
    (("Guanosine 5'-(trihydrogen diphosphate), monoanhydride"
      " with phosphorothioic acid. "),
     ("Guanosine 5'-(trihydrogen diphosphate), monoanhydride"
      " with phosphorothioic acid is ")),
    ("5'-Uridylic acid. ", "5'-Uridylic acid is "),
    ("5'-Adenylic acid, ", "5'-Adenylic acid is "),
    ("Uridine 5'-(tetrahydrogen triphosphate). ",
     "Uridine 5'-(tetrahydrogen triphosphate). is "),
    ("Inosine 5'-Monophosphate. ", "Inosine 5'-Monophosphate. is "),
    ("Pivaloyloxymethyl butyrate (AN-9), ",
     "Pivaloyloxymethyl butyrate (AN-9) is "),
    ("4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine. ",
     "4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine is "),
    ("Cardamonin (also known as Dihydroxymethoxychalcone), ",
     "Cardamonin (also known as Dihydroxymethoxychalcone) is "),
    ("Lithium has been used to treat ", "Lithium is "),
    ("4,4'-Methylenebis ", "4,4'-Methylenebis is "),
    ("2,3,7,8-Tetrachlorodibenzo-p-dioxin",
     "2,3,7,8-Tetrachlorodibenzo-p-dioxin is "),
    ("Exposure to 2,4,5-trichlorophenol ", "2,4,5-Trichlorophenol exposure "),
)

# Leading fragments that contain a ". <Capital>" sequence without ending a
# sentence; the sentence-boundary scan starts right after them.
_NON_TERMINAL_PREFIXES = (
    'C.I. ',
    'Nectriapyrone. D ',
    'Salmonella enterica sv. Minnesota LPS core oligosaccharide',
)


def clean_up_description(description: str) -> str:
    """Normalize a compound description and return its first sentence.

    A trailing space is appended, a leading ``"Pure "`` adjective is dropped,
    a known typo and a fixed list of special-case phrasings are rewritten, and
    the text is then cut after the first ``"."`` that is followed by a space
    and an upper-case ASCII letter.

    Args:
        description: The raw description text.

    Returns:
        The first sentence including its terminating period. If no sentence
        boundary is found, the whole cleaned description without the
        appended trailing space is returned.
    """
    description = description + " "

    # Only the leading adjective is meant to be removed, so limit the
    # replacement to the first (i.e. the prefix) occurrence.
    if description.startswith("Pure "):
        description = description.replace("Pure ", "", 1)
    # Fix a known typo.
    if description.startswith("Mercurycombines"):
        description = description.replace("Mercurycombines",
                                          "Mercury combines")

    for old, new in _SPECIAL_CASE_REPLACEMENTS:
        description = description.replace(old, new)

    start_index = 0
    for prefix in _NON_TERMINAL_PREFIXES:
        if description.startswith(prefix):
            start_index = len(prefix)
            break

    length = len(description)
    # Default cut: everything except the appended trailing space. This also
    # covers the case where the prefix spans the whole text and the scan
    # below has nothing to iterate over.
    end = length - 1
    for index in range(start_index, length - 2):
        if (description[index] == '.' and description[index + 1] == ' '
                and 'A' <= description[index + 2] <= 'Z'):
            end = index + 1
            break

    return description[:end]
123
+
124
+
125
# Verb phrases that separate a compound name from the rest of the first
# sentence. They are replaced one after another, in this order.
_NAME_SPLIT_PHRASES = (
    ' is ',
    ' are ',
    ' was ',
    ' were ',
    ' appears ',
    ' occurs ',
    ' stands for ',
    ' belongs to ',
    ' exists ',  # only for CID=11443
    ' has been used in trials ',
    ' has been investigated ',
    ' has many uses ',
)


def extract_name(
    name_raw: Optional[str],
    description: str,
) -> Tuple[Optional[str], str, str]:
    """Extract the compound name from the first sentence of a description.

    The first sentence (see :func:`clean_up_description`) is split at the
    first known verb phrase; the text in front of it is taken as the name.
    If no verb phrase is present, :obj:`name_raw` is used when the sentence
    starts with it.

    Args:
        name_raw: The name reported by the record, or :obj:`None` if the
            record carries no name.
        description: The raw description text.

    Returns:
        A tuple :obj:`(extracted_name, extracted_description,
        first_sentence)`. :obj:`extracted_name` is :obj:`None` if no name
        could be determined. :obj:`extracted_description` is
        :obj:`description` with every occurrence of the name replaced by
        ``"This molecule"`` / ``"These molecules"`` (or unchanged if no name
        was found). :obj:`first_sentence` is the first sentence with the verb
        phrases replaced by the internal splitter token.
    """
    first_sentence = clean_up_description(description)

    splitter = ' -- -- '
    # Pick the placeholder before the verbs are replaced by the splitter.
    if ' are ' in first_sentence or ' were ' in first_sentence:
        replaced_words = 'These molecules'
    else:
        replaced_words = 'This molecule'

    for phrase in _NAME_SPLIT_PHRASES:
        first_sentence = first_sentence.replace(phrase, splitter)

    extracted_name: Optional[str] = None
    if splitter in first_sentence:
        extracted_name = first_sentence.split(splitter, 1)[0]
    elif name_raw is not None:
        # `name_raw` may be missing; only consult it when it is a string.
        if first_sentence.startswith(name_raw):
            extracted_name = name_raw
        elif name_raw in first_sentence:
            # The name occurs mid-sentence: report it, but do not use it.
            print("=====", name_raw)
            print("first sentence: ", first_sentence)

    # An empty name would make `str.replace` insert the placeholder between
    # every character of the description, so treat it as "no name found".
    if not extracted_name:
        extracted_name = None

    if extracted_name is not None:
        extracted_description = description.replace(extracted_name,
                                                    replaced_words)
    else:
        extracted_description = description

    return extracted_name, extracted_description, first_sentence
172
+
173
+
174
# Number of compound IDs covered by one PubChem "CURRENT-Full" SDF archive.
_SDF_BLOCK_SIZE = 500000


def _compound_file_name(block_id: int) -> str:
    """Return the SDF archive file name for the zero-based ``block_id``."""
    l_id = block_id * _SDF_BLOCK_SIZE + 1
    r_id = (block_id + 1) * _SDF_BLOCK_SIZE
    return f"Compound_{l_id:09d}_{r_id:09d}.sdf.gz"


def _filter_sdf_block(
    block_id: int,
    step2_folder: str,
    step3_folder: str,
    target_cids: set,
) -> None:
    """Copy the molecules of one SDF archive whose CID is in ``target_cids``.

    Reads ``{step2_folder}/Compound_<l>_<r>.sdf.gz`` and writes the matching
    molecules to ``{step3_folder}/filtered_{block_id}.sdf``. Defined at
    module level so that it can be pickled for :class:`multiprocessing.Pool`.
    """
    from rdkit import Chem

    valid_mol_count = 0
    writer = Chem.SDWriter(f'{step3_folder}/filtered_{block_id}.sdf')
    try:
        archive_path = f"{step2_folder}/{_compound_file_name(block_id)}"
        with gzip.open(archive_path) as gzip_loader:
            suppl = Chem.ForwardSDMolSupplier(gzip_loader)
            for mol in tqdm(suppl):
                if mol is None:
                    continue
                cid = mol.GetProp("PUBCHEM_COMPOUND_CID")
                if cid not in target_cids:
                    continue

                writer.write(mol)
                valid_mol_count += 1
    finally:
        # Always release the output file, even if reading fails midway.
        writer.close()

    print(f"block id: {block_id}\nfound {valid_mol_count}\n\n")
    sys.stdout.flush()


class MoleculeGPTDataset(InMemoryDataset):
    r"""The dataset from the `"MoleculeGPT: Instruction Following Large
    Language Models for Molecular Property Prediction"
    <https://ai4d3.github.io/2023/papers/34.pdf>`_ paper.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
        total_page_num (int, optional): The number of pages from PubChem.
            (default: :obj:`10`)
        total_block_num (int, optional): The blocks of SDF files from PubChem.
            (default: :obj:`1`)
        num_units (int, optional): Number of units of the sample.
            (default: :obj:`-1`, which means all units will be used)
    """
    description_url = (
        'https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/annotations/'
        'heading/json?heading_type=Compound&heading=Record+Description&page={}'
    )
    compound_url = ('https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/'
                    'CURRENT-Full/SDF')

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
        total_page_num: int = 10,
        total_block_num: int = 1,
        num_units: int = -1,
    ):
        # These are read by `download()` / `process()`, which the base class
        # may invoke from its constructor, so set them first.
        self.total_page_num = total_page_num
        self.total_block_num = total_block_num
        self.num_units = num_units

        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return ['pubchem.csv']

    @property
    def processed_file_names(self) -> List[str]:
        return ['data.pt']

    def download(self) -> None:
        """Fetch compound descriptions and SDF archives from PubChem.

        Step 01 queries :attr:`description_url` page by page, extracts names
        and descriptions and stores them as JSON files in :obj:`raw_dir`.
        Step 02 downloads :attr:`total_block_num` SDF archives. Each step is
        skipped if its output folder already exists.
        """
        # Step 01. Extract description
        step1_folder = f"{self.raw_dir}/step_01_PubChemSTM_description"
        if not os.path.exists(step1_folder):
            os.makedirs(step1_folder)
            valid_CID_set = set()
            CID2name_raw = defaultdict(list)
            CID2name_extracted = defaultdict(list)
            CID2text_raw = defaultdict(list)
            CID2text_extracted = defaultdict(list)

            for page_index in tqdm(range(self.total_page_num)):
                page_num = page_index + 1
                page_path = (
                    f"{step1_folder}/Compound_description_{page_num}.txt")
                # The context manager guarantees the per-page file is closed.
                with open(page_path, "w", encoding="utf-8") as f_out:
                    description_data = requests.get(
                        self.description_url.format(page_num)).json()

                    description_data = description_data["Annotations"]
                    assert description_data["Page"] == page_num

                    record_list = description_data["Annotation"]

                    for record in record_list:
                        # Best effort: records that lack the expected fields
                        # are skipped rather than aborting the download.
                        try:
                            CID = record["LinkedRecords"]["CID"][0]
                            if "Name" in record:
                                name_raw = record["Name"]
                                CID2name_raw[CID].append(name_raw)
                            else:
                                name_raw = None

                            data_list = record["Data"]
                            for data in data_list:
                                description = data["Value"][
                                    "StringWithMarkup"][0]["String"].strip()

                                extracted_name, extracted_description, _ = (
                                    extract_name(name_raw, description))
                                if extracted_name is not None:
                                    CID2name_extracted[CID].append(
                                        extracted_name)

                                CID2text_raw[CID].append(description)
                                CID2text_extracted[CID].append(
                                    extracted_description)

                                valid_CID_set.add(CID)
                                f_out.write(f"{CID}\n")
                                f_out.write(f"{extracted_description}\n\n")
                        except Exception:
                            continue

            valid_CID_list = sorted(valid_CID_set)
            print(f"Total CID (with raw name) {len(CID2name_raw)}")
            print(f"Total CID (with extracted name) {len(CID2name_extracted)}")
            print(f"Total CID {len(valid_CID_list)}")

            with open(f"{self.raw_dir}/CID2name_raw.json", "w") as f:
                json.dump(CID2name_raw, f)

            with open(f"{self.raw_dir}/CID2name.json", "w") as f:
                json.dump(CID2name_extracted, f)

            with open(f"{self.raw_dir}/CID2text_raw.json", "w") as f:
                json.dump(CID2text_raw, f)

            with open(f"{self.raw_dir}/CID2text.json", "w") as f:
                json.dump(CID2text_extracted, f)

        # Step 02. Download SDF Files
        step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF"
        if not os.path.exists(step2_folder):
            for block_id in tqdm(range(self.total_block_num)):
                compound_file_name = _compound_file_name(block_id)
                download_url(f"{self.compound_url}/{compound_file_name}",
                             step2_folder)

    def process(self, use_mp: bool = False) -> None:
        """Build the processed dataset from the downloaded raw files.

        Without :obj:`rdkit`, :obj:`raw_paths[0]` is loaded as a list of
        pre-processed dictionaries. Otherwise the SDF archives are filtered to
        the compounds that have a description (Step 03), merged into a single
        SDF file (Step 04) and converted to
        :class:`~torch_geometric.data.Data` objects (Step 05), with an
        instruction generated by an LLM for every molecule.

        Args:
            use_mp (bool, optional): If set to :obj:`True`, filters the SDF
                archives in a process pool. (default: :obj:`False`)
        """
        try:
            from rdkit import Chem
            from rdkit.Chem.rdchem import BondType as BT
            WITH_RDKIT = True

        except ImportError:
            WITH_RDKIT = False

        if not WITH_RDKIT:
            print(("Using a pre-processed version of the dataset. Please "
                   "install 'rdkit' to alternatively process the raw data."),
                  file=sys.stderr)

            data_list = fs.torch_load(self.raw_paths[0])
            data_list = [Data(**data_dict) for data_dict in data_list]

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            self.save(data_list, self.processed_paths[0])
            return

        # Step 03. Filter out SDF
        step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF"
        step3_folder = f"{self.raw_dir}/step_03_PubChemSTM_filtered"
        if not os.path.exists(step3_folder):
            os.makedirs(step3_folder)
            with open(f"{self.raw_dir}/CID2text.json") as f:
                CID2text = json.load(f)
            target_CID_list = set(CID2text.keys())

            if use_mp:
                from functools import partial

                # A module-level function bound via `partial` is picklable,
                # which `Pool.map` requires; a nested function is not.
                worker = partial(
                    _filter_sdf_block,
                    step2_folder=step2_folder,
                    step3_folder=step3_folder,
                    target_cids=target_CID_list,
                )
                num_process = multiprocessing.cpu_count()
                print(f"{num_process} CPUs")
                num_process = 8
                with Pool(num_process) as p:
                    p.map(worker, range(self.total_block_num))
            else:
                for block_id in range(self.total_block_num):
                    _filter_sdf_block(block_id, step2_folder, step3_folder,
                                      target_CID_list)

        # Step 04. Merge SDF
        with open(f"{self.raw_dir}/CID2text.json") as f:
            CID2text = json.load(f)
        target_CID_list = set(CID2text.keys())
        print(f'The length of target_CID_list: {len(target_CID_list)}')

        writer = Chem.SDWriter(f'{self.raw_dir}/molecules.sdf')

        found_CID_set = set()
        try:
            # Step 03 writes exactly one file per block in
            # `range(total_block_num)`.
            for block_id in range(self.total_block_num):
                compound_file_path = f"{step3_folder}/filtered_{block_id}.sdf"
                try:
                    suppl = Chem.SDMolSupplier(compound_file_path)

                    for mol in tqdm(suppl):
                        # Skip entries the supplier could not parse instead
                        # of dropping the remainder of the block.
                        if mol is None:
                            continue
                        writer.write(mol)
                        cid = mol.GetProp("PUBCHEM_COMPOUND_CID")
                        found_CID_set.add(cid)
                except Exception:
                    print(f"block id: {block_id} with 0 valid SDF file")
                    continue
        finally:
            writer.close()
        print(f"In total: {len(found_CID_set)} molecules")

        # Step 05. Convert to PyG data format
        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5}
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

        data_list = []
        # Real data
        CID2text_file = f'{self.raw_dir}/CID2text.json'

        with open(CID2text_file) as f:
            CID2text_data = json.load(f)

        suppl = Chem.SDMolSupplier(f'{self.raw_dir}/molecules.sdf')

        llm = LLM(
            # model_name='lmsys/vicuna-7b-v1.5',
            model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
            num_params=1,
            dtype=torch.bfloat16,
        )
        prompt = ("Propose a question regarding the molecule '∼' "
                  "whose answer is: {}:")
        for mol in tqdm(suppl):
            if mol is None or not mol.HasProp('PUBCHEM_COMPOUND_CID'):
                continue

            CID = mol.GetProp("PUBCHEM_COMPOUND_CID")
            CAN_SMILES = mol.GetProp("PUBCHEM_SMILES")

            m: Chem.Mol = Chem.MolFromSmiles(CAN_SMILES)
            if m is None:
                continue
            RDKit_CAN_SMILES = Chem.MolToSmiles(m)

            ground_truth = CID2text_data[CID][0]

            instruction = llm.inference([prompt.format(ground_truth)])[0]

            # Atoms outside the known symbols share the last index (5).
            x: torch.Tensor = torch.tensor(
                [types.get(atom.GetSymbol(), 5) for atom in m.GetAtoms()])
            x = one_hot(x, num_classes=len(types), dtype=torch.float)

            # Every bond is stored in both directions.
            rows, cols, edge_types = [], [], []
            for bond in m.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                edge_types += [bonds[bond.GetBondType()]] * 2
                rows += [i, j]
                cols += [j, i]

            edge_index = torch.tensor([rows, cols], dtype=torch.long)
            edge_type = torch.tensor(edge_types, dtype=torch.long)
            edge_attr = one_hot(edge_type, num_classes=len(bonds))

            data = Data(
                x=x,
                edge_index=edge_index,
                edge_attr=edge_attr,
                smiles=RDKit_CAN_SMILES,
                instruction=instruction,
                y=ground_truth,
            )

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

            # Stop early once the requested number of samples is collected.
            if self.num_units > 0 and len(data_list) >= self.num_units:
                break

        self.save(data_list, self.processed_paths[0])
@@ -210,8 +210,9 @@ class MoleculeNet(InMemoryDataset):
210
210
  data.y = y
211
211
 
212
212
  if data.num_nodes == 0:
213
- warnings.warn(f"Skipping molecule '{smiles}' since it "
214
- f"resulted in zero atoms")
213
+ warnings.warn(
214
+ f"Skipping molecule '{smiles}' since it "
215
+ f"resulted in zero atoms", stacklevel=2)
215
216
  continue
216
217
 
217
218
  if self.pre_filter is not None and not self.pre_filter(data):
@@ -2,8 +2,6 @@ import os
2
2
  import os.path as osp
3
3
  from typing import Callable, List, Optional
4
4
 
5
- import torch
6
-
7
5
  from torch_geometric.data import (
8
6
  Data,
9
7
  InMemoryDataset,
@@ -110,7 +108,7 @@ class NeuroGraphDataset(InMemoryDataset):
110
108
  fs.rm(osp.join(self.raw_dir, self.name))
111
109
 
112
110
  def process(self) -> None:
113
- data, slices = torch.load(self.raw_paths[0])
111
+ data, slices = fs.torch_load(self.raw_paths[0])
114
112
 
115
113
  num_samples = slices['x'].size(0) - 1
116
114
  data_list: List[Data] = []
@@ -147,7 +147,7 @@ class OGB_MAG(InMemoryDataset):
147
147
  for node_type in ['author', 'institution', 'field_of_study']:
148
148
  data[node_type].num_nodes = num_nodes_df[node_type].tolist()[0]
149
149
  else:
150
- emb_dict = torch.load(self.raw_paths[-1])
150
+ emb_dict = fs.torch_load(self.raw_paths[-1])
151
151
  for key, value in emb_dict.items():
152
152
  if key != 'paper':
153
153
  data[key].x = value
@@ -41,6 +41,12 @@ class OPFDataset(InMemoryDataset):
41
41
  If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
42
42
  case_name (str, optional): The name of the original pglib-opf case.
43
43
  (default: :obj:`"pglib_opf_case14_ieee"`)
44
+ num_groups (int, optional): The dataset is divided into 20 groups with
45
+ each group containing 15,000 samples.
46
+ For large networks, this amount of data can be overwhelming.
47
+ The :obj:`num_groups` parameter controls the amount of data being
48
+ downloaded. Allowed values are :obj:`[1, 20]`.
49
+ (default: :obj:`20`)
44
50
  topological_perturbations (bool, optional): Whether to use the dataset
45
51
  with added topological perturbations. (default: :obj:`False`)
46
52
  transform (callable, optional): A function/transform that takes in
@@ -76,6 +82,7 @@ class OPFDataset(InMemoryDataset):
76
82
  'pglib_opf_case10000_goc',
77
83
  'pglib_opf_case13659_pegase',
78
84
  ] = 'pglib_opf_case14_ieee',
85
+ num_groups: int = 20,
79
86
  topological_perturbations: bool = False,
80
87
  transform: Optional[Callable] = None,
81
88
  pre_transform: Optional[Callable] = None,
@@ -85,6 +92,7 @@ class OPFDataset(InMemoryDataset):
85
92
 
86
93
  self.split = split
87
94
  self.case_name = case_name
95
+ self.num_groups = num_groups
88
96
  self.topological_perturbations = topological_perturbations
89
97
 
90
98
  self._release = 'dataset_release_1'
@@ -103,11 +111,12 @@ class OPFDataset(InMemoryDataset):
103
111
 
104
112
  @property
105
113
  def processed_dir(self) -> str:
106
- return osp.join(self.root, self._release, self.case_name, 'processed')
114
+ return osp.join(self.root, self._release, self.case_name,
115
+ f'processed_{self.num_groups}')
107
116
 
108
117
  @property
109
118
  def raw_file_names(self) -> List[str]:
110
- return [f'{self.case_name}_{i}.tar.gz' for i in range(20)]
119
+ return [f'{self.case_name}_{i}.tar.gz' for i in range(self.num_groups)]
111
120
 
112
121
  @property
113
122
  def processed_file_names(self) -> List[str]:
@@ -124,7 +133,7 @@ class OPFDataset(InMemoryDataset):
124
133
  val_data_list = []
125
134
  test_data_list = []
126
135
 
127
- for group in tqdm.tqdm(range(20)):
136
+ for group in tqdm.tqdm(range(self.num_groups)):
128
137
  tmp_dir = osp.join(
129
138
  self.raw_dir,
130
139
  'gridopt-dataset-tmp',
@@ -139,11 +148,14 @@ class OPFDataset(InMemoryDataset):
139
148
 
140
149
  grid = obj['grid']
141
150
  solution = obj['solution']
151
+ metadata = obj['metadata']
142
152
 
143
153
  # Graph-level properties:
144
154
  data = HeteroData()
145
155
  data.x = torch.tensor(grid['context']).view(-1)
146
156
 
157
+ data.objective = torch.tensor(metadata['objective'])
158
+
147
159
  # Nodes (only some have a target):
148
160
  data['bus'].x = torch.tensor(grid['nodes']['bus'])
149
161
  data['bus'].y = torch.tensor(solution['nodes']['bus'])
@@ -193,9 +205,11 @@ class OPFDataset(InMemoryDataset):
193
205
  data = self.pre_transform(data)
194
206
 
195
207
  i = int(name.split('.')[0].split('_')[1])
196
- if i < 270_000:
208
+ train_limit = int(15_000 * self.num_groups * 0.9)
209
+ val_limit = train_limit + int(15_000 * self.num_groups * 0.05)
210
+ if i < train_limit:
197
211
  train_data_list.append(data)
198
- elif i < 285_000:
212
+ elif i < val_limit:
199
213
  val_data_list.append(data)
200
214
  else:
201
215
  test_data_list.append(data)
@@ -66,7 +66,7 @@ class PascalPF(InMemoryDataset):
66
66
  super().__init__(root, transform, pre_transform, pre_filter,
67
67
  force_reload=force_reload)
68
68
  self.load(self.processed_paths[0])
69
- self.pairs = torch.load(self.processed_paths[1])
69
+ self.pairs = fs.torch_load(self.processed_paths[1])
70
70
 
71
71
  @property
72
72
  def raw_file_names(self) -> List[str]:
@@ -7,6 +7,7 @@ from tqdm import tqdm
7
7
 
8
8
  from torch_geometric.data import Data, OnDiskDataset, download_url, extract_zip
9
9
  from torch_geometric.data.data import BaseData
10
+ from torch_geometric.io import fs
10
11
  from torch_geometric.utils import from_smiles as _from_smiles
11
12
 
12
13
 
@@ -72,7 +73,7 @@ class PCQM4Mv2(OnDiskDataset):
72
73
  self.from_smiles = from_smiles or _from_smiles
73
74
  super().__init__(root, transform, backend=backend, schema=schema)
74
75
 
75
- split_idx = torch.load(self.raw_paths[1])
76
+ split_idx = fs.torch_load(self.raw_paths[1])
76
77
  self._indices = split_idx[self.split_mapping[split]].tolist()
77
78
 
78
79
  @property
@@ -107,7 +107,8 @@ class PPI(InMemoryDataset):
107
107
  for s, split in enumerate(['train', 'valid', 'test']):
108
108
  path = osp.join(self.raw_dir, f'{split}_graph.json')
109
109
  with open(path) as f:
110
- G = nx.DiGraph(json_graph.node_link_graph(json.load(f)))
110
+ G = nx.DiGraph(
111
+ json_graph.node_link_graph(json.load(f), edges="links"))
111
112
 
112
113
  x = np.load(osp.join(self.raw_dir, f'{split}_feats.npy'))
113
114
  x = torch.from_numpy(x).to(torch.float)