pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic. Click here for more details.

Files changed (268)
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
@@ -0,0 +1,263 @@
1
+ import sys
2
+ from typing import Any, Callable, Dict, List, Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ from tqdm import tqdm
7
+
8
+ from torch_geometric.data import (
9
+ Data,
10
+ InMemoryDataset,
11
+ download_google_url,
12
+ extract_zip,
13
+ )
14
+ from torch_geometric.io import fs
15
+
16
+
17
def safe_index(lst: List[Any], e: Any) -> int:
    r"""Returns the position of :obj:`e` in :obj:`lst`.

    Values that do not occur in :obj:`lst` are mapped to its last position
    (:obj:`len(lst) - 1`), typically a catch-all :obj:`'misc'` entry of a
    feature vocabulary. For an empty list this yields :obj:`-1`.

    Args:
        lst (List[Any]): The vocabulary to search in.
        e (Any): The value to look up (e.g., an integer or an RDKit enum
            member), compared via :obj:`==`.
    """
    try:
        # Single scan instead of `e in lst` followed by `lst.index(e)`.
        return lst.index(e)
    except ValueError:
        return len(lst) - 1
19
+
20
+
21
class GitMolDataset(InMemoryDataset):
    r"""The dataset from the `"GIT-Mol: A Multi-modal Large Language Model
    for Molecular Science with Graph, Image, and Text"
    <https://arxiv.org/pdf/2308.06911>`_ paper.

    Every example holds a molecular graph (:obj:`x`, :obj:`edge_index`,
    :obj:`edge_attr`) built from its SMILES string (:obj:`smiles`), an image
    of the molecule of shape :obj:`[1, 3, 224, 224]` (:obj:`image`), and a
    textual summary (:obj:`caption`).

    .. note::

        Instantiating the dataset requires :obj:`torchvision`. Processing
        the raw data additionally requires :obj:`pandas`, :obj:`Pillow` and
        :obj:`rdkit`.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
        split (int, optional): Datasets split, train/valid/test=0/1/2.
            Random image augmentations are only used for the training split.
            (default: :obj:`0`)
    """

    raw_url_id = '1loBXabD6ncAFY-vanRsVtRUSFkEtBweg'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
        split: int = 0,
    ):
        from torchvision import transforms

        # Must be set before `super().__init__()`, which queries
        # `processed_file_names` (and thus `self.split`).
        self.split = split

        # Normalization constants (presumably the usual ImageNet statistics).
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        if self.split == 0:
            # Training split: random augmentations. Since `img_transform` is
            # only applied inside `process()`, one sampled augmentation per
            # image is baked into the saved file.
            self.img_transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.RandomRotation(15),
                transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            self.img_transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                normalize,
            ])

        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)

        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        # NOTE(review): `process()` reads these files from the
        # `igcdata_toy/` sub-directory of `raw_dir`, while `raw_paths` point
        # to `raw_dir` itself - confirm against the archive layout.
        return ['train_3500.pkl', 'valid_450.pkl', 'test_450.pkl']

    @property
    def processed_file_names(self) -> str:
        return ['train.pt', 'valid.pt', 'test.pt'][self.split]

    def download(self) -> None:
        file_path = download_google_url(
            self.raw_url_id,
            self.raw_dir,
            'gitmol.zip',
        )
        extract_zip(file_path, self.raw_dir)

    def process(self) -> None:
        r"""Builds the graph/image/text examples of the selected split and
        saves them to :obj:`self.processed_paths[0]`.

        Raises:
            ImportError: If :obj:`rdkit` is not installed. The raw data is
                only available as SMILES strings, so there is no
                pre-processed fallback.
        """
        import pandas as pd
        from PIL import Image

        try:
            from rdkit import Chem, RDLogger
        except ImportError as e:
            raise ImportError(
                f"'{self.__class__.__name__}' requires 'rdkit' to process "
                f"the raw data. Please install it via 'pip install rdkit'."
            ) from e
        RDLogger.DisableLog('rdApp.*')  # type: ignore[attr-defined]

        allowable_features: Dict[str, List[Any]] = {
            'possible_atomic_num_list':
            list(range(1, 119)) + ['misc'],
            'possible_formal_charge_list':
            [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'],
            'possible_chirality_list': [
                Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
                Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
                Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
                Chem.rdchem.ChiralType.CHI_OTHER
            ],
            'possible_hybridization_list': [
                Chem.rdchem.HybridizationType.SP,
                Chem.rdchem.HybridizationType.SP2,
                Chem.rdchem.HybridizationType.SP3,
                Chem.rdchem.HybridizationType.SP3D,
                Chem.rdchem.HybridizationType.SP3D2,
                Chem.rdchem.HybridizationType.UNSPECIFIED, 'misc'
            ],
            'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
            'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6],
            'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
            'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'],
            'possible_is_aromatic_list': [False, True],
            'possible_is_in_ring_list': [False, True],
            'possible_bond_type_list': [
                Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC,
                Chem.rdchem.BondType.ZERO
            ],
            'possible_bond_dirs': [  # only for double bond stereo information
                Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT,
                Chem.rdchem.BondDir.ENDDOWNRIGHT
            ],
            'possible_bond_stereo_list': [
                Chem.rdchem.BondStereo.STEREONONE,
                Chem.rdchem.BondStereo.STEREOZ,
                Chem.rdchem.BondStereo.STEREOE,
                Chem.rdchem.BondStereo.STEREOCIS,
                Chem.rdchem.BondStereo.STEREOTRANS,
                Chem.rdchem.BondStereo.STEREOANY,
            ],
            'possible_is_conjugated_list': [False, True]
        }

        df = pd.read_pickle(
            f'{self.raw_dir}/igcdata_toy/{self.raw_file_names[self.split]}')

        data_list = []
        for _, r in tqdm(df.iterrows(), total=df.shape[0]):
            smiles = r['isosmiles']
            mol = Chem.MolFromSmiles(smiles.strip('\n'))
            if mol is None:  # Skip SMILES strings RDKit cannot parse.
                continue

            # text
            summary = r['summary']

            # image
            cid = r['cid']
            img_file = f'{self.raw_dir}/igcdata_toy/imgs/CID_{cid}.png'
            # Context manager so the file handle is released right away.
            with Image.open(img_file) as raw_img:
                img = self.img_transform(raw_img.convert('RGB')).unsqueeze(0)

            # graph: nine categorical features per atom.
            atom_features_list = []
            for atom in mol.GetAtoms():
                atom_feature = [
                    safe_index(
                        allowable_features['possible_atomic_num_list'],
                        atom.GetAtomicNum()),
                    allowable_features['possible_chirality_list'].index(
                        atom.GetChiralTag()),
                    safe_index(allowable_features['possible_degree_list'],
                               atom.GetTotalDegree()),
                    safe_index(
                        allowable_features['possible_formal_charge_list'],
                        atom.GetFormalCharge()),
                    safe_index(allowable_features['possible_numH_list'],
                               atom.GetTotalNumHs()),
                    safe_index(
                        allowable_features['possible_number_radical_e_list'],
                        atom.GetNumRadicalElectrons()),
                    safe_index(
                        allowable_features['possible_hybridization_list'],
                        atom.GetHybridization()),
                    allowable_features['possible_is_aromatic_list'].index(
                        atom.GetIsAromatic()),
                    allowable_features['possible_is_in_ring_list'].index(
                        atom.IsInRing()),
                ]
                atom_features_list.append(atom_feature)
            # `view` keeps the `[num_atoms, 9]` shape for atom-less molecules.
            x = torch.tensor(atom_features_list, dtype=torch.long).view(-1, 9)

            # Three categorical features per bond; each bond is stored in
            # both directions.
            edges_list = []
            edge_features_list = []
            for bond in mol.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                edge_feature = [
                    safe_index(
                        allowable_features['possible_bond_type_list'],
                        bond.GetBondType()),
                    allowable_features['possible_bond_stereo_list'].index(
                        bond.GetStereo()),
                    allowable_features['possible_is_conjugated_list'].
                    index(bond.GetIsConjugated()),
                ]
                edges_list += [(i, j), (j, i)]
                edge_features_list += [edge_feature, edge_feature]

            # The `view` calls guarantee `[2, num_edges]` and
            # `[num_edges, 3]` shapes even for molecules without bonds.
            edge_index = torch.tensor(
                edges_list,
                dtype=torch.long,
            ).view(-1, 2).t().contiguous()
            edge_attr = torch.tensor(
                edge_features_list,
                dtype=torch.long,
            ).view(-1, 3)

            data = Data(
                x=x,
                edge_index=edge_index,
                smiles=smiles,
                edge_attr=edge_attr,
                image=img,
                caption=summary,
            )

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        self.save(data_list, self.processed_paths[0])
@@ -12,6 +12,7 @@ from torch_geometric.data import (
12
12
  download_url,
13
13
  extract_zip,
14
14
  )
15
+ from torch_geometric.io import fs
15
16
  from torch_geometric.utils import remove_self_loops
16
17
 
17
18
 
@@ -181,7 +182,7 @@ class GNNBenchmarkDataset(InMemoryDataset):
181
182
  data_list = self.process_CSL()
182
183
  self.save(data_list, self.processed_paths[0])
183
184
  else:
184
- inputs = torch.load(self.raw_paths[0])
185
+ inputs = fs.torch_load(self.raw_paths[0])
185
186
  for i in range(len(inputs)):
186
187
  data_list = [Data(**data_dict) for data_dict in inputs[i]]
187
188
 
@@ -197,7 +198,7 @@ class GNNBenchmarkDataset(InMemoryDataset):
197
198
  with open(self.raw_paths[0], 'rb') as f:
198
199
  adjs = pickle.load(f)
199
200
 
200
- ys = torch.load(self.raw_paths[1]).tolist()
201
+ ys = fs.torch_load(self.raw_paths[1]).tolist()
201
202
 
202
203
  data_list = []
203
204
  for adj, y in zip(adjs, ys):
@@ -123,8 +123,8 @@ class HGBDataset(InMemoryDataset):
123
123
  start = info.index('LINK\tSTART\tEND\tMEANING') + 1
124
124
  end = info[start:].index('')
125
125
  for key, row in enumerate(info[start:start + end]):
126
- row = row.split('\t')[1:]
127
- src, dst, rel = (v for v in row if v != '')
126
+ edge = row.split('\t')[1:]
127
+ src, dst, rel = (v for v in edge if v != '')
128
128
  src, dst = n_types[int(src)], n_types[int(dst)]
129
129
  rel = rel.split('-')[1]
130
130
  e_types[key] = (src, rel, dst)
@@ -81,7 +81,7 @@ class HM(InMemoryDataset):
81
81
  xs.append(torch.from_numpy(x).to(torch.float))
82
82
 
83
83
  x = torch.from_numpy(df['age'].values).to(torch.float).view(-1, 1)
84
- x = x.nan_to_num(nan=x.nanmean())
84
+ x = x.nan_to_num(nan=x.nanmean()) # type: ignore
85
85
  xs.append(x / x.max())
86
86
 
87
87
  data['customer'].x = torch.cat(xs, dim=-1)
@@ -0,0 +1,134 @@
1
+ import json
2
+ import sys
3
+ from typing import Callable, List, Optional
4
+
5
+ import torch
6
+ from tqdm import tqdm
7
+
8
+ from torch_geometric.data import Data, InMemoryDataset
9
+ from torch_geometric.io import fs
10
+ from torch_geometric.utils import one_hot
11
+
12
+
13
class InstructMolDataset(InMemoryDataset):
    r"""The dataset from the `"InstructMol: Multi-Modal Integration for
    Building a Versatile and Reliable Molecular Assistant in Drug Discovery"
    <https://arxiv.org/pdf/2311.16208>`_ paper.

    Every (question, answer) pair attached to a SMILES string becomes its own
    example, holding the molecular graph (:obj:`x`, :obj:`edge_index`,
    :obj:`edge_attr`), the SMILES string (:obj:`smiles`), the question
    (:obj:`instruction`) and the answer (:obj:`y`).

    .. note::

        Processing the raw data requires :obj:`rdkit`.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    raw_url = 'https://huggingface.co/datasets/OpenMol/PubChemSFT/resolve/main'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ):
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return ['all_clean.json']

    @property
    def processed_file_names(self) -> List[str]:
        return ['data.pt']

    def download(self) -> None:
        print('downloading dataset...')
        fs.cp(f'{self.raw_url}/all_clean.json', self.raw_dir)

    def process(self) -> None:
        r"""Converts every SMILES string of the raw JSON file into a graph and
        saves one example per (question, answer) pair.

        Raises:
            ImportError: If :obj:`rdkit` is not installed. The raw data is a
                JSON file of SMILES strings, so there is no pre-processed
                fallback.
        """
        try:
            from rdkit import Chem
            from rdkit.Chem.rdchem import BondType as BT
        except ImportError as e:
            raise ImportError(
                f"'{self.__class__.__name__}' requires 'rdkit' to process "
                f"the raw data. Please install it via 'pip install rdkit'."
            ) from e

        # types of atom and bond; unlisted elements map to the last class.
        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5}
        # NOTE(review): bond types outside this mapping (e.g. dative bonds)
        # raise a `KeyError` below - confirm the raw data never has them.
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

        # load data (closing the file and decoding as UTF-8 JSON).
        with open(self.raw_paths[0], encoding='utf-8') as f:
            mols = json.load(f)

        data_list = []
        for smiles, qa_pairs in tqdm(mols.items(), total=len(mols)):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue

            # Explicit long dtype so that `one_hot` also works for atom-less
            # molecules (`torch.tensor([])` would default to float).
            atom_types = torch.tensor(
                [
                    types.get(atom.GetSymbol(), types['Unknow'])
                    for atom in mol.GetAtoms()
                ],
                dtype=torch.long,
            )
            x = one_hot(atom_types, num_classes=len(types), dtype=torch.float)

            # Every bond is stored in both directions.
            rows, cols, edge_types = [], [], []
            for bond in mol.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                edge_types += [bonds[bond.GetBondType()]] * 2
                rows += [i, j]
                cols += [j, i]

            edge_index = torch.tensor([rows, cols], dtype=torch.long)
            edge_type = torch.tensor(edge_types, dtype=torch.long)
            edge_attr = one_hot(edge_type, num_classes=len(bonds))

            for question, answer in qa_pairs:
                data = Data(
                    x=x,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    smiles=smiles,
                    instruction=question,
                    y=answer,
                )

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

        self.save(data_list, self.processed_paths[0])
@@ -5,6 +5,7 @@ import numpy as np
5
5
  import torch
6
6
 
7
7
  from torch_geometric.data import Data, InMemoryDataset, download_url
8
+ from torch_geometric.io import fs
8
9
  from torch_geometric.utils import one_hot
9
10
 
10
11
 
@@ -115,9 +116,9 @@ class LINKXDataset(InMemoryDataset):
115
116
 
116
117
  def _process_wiki(self) -> Data:
117
118
  paths = {x.split('/')[-1]: x for x in self.raw_paths}
118
- x = torch.load(paths['wiki_features2M.pt'])
119
- edge_index = torch.load(paths['wiki_edges2M.pt']).t().contiguous()
120
- y = torch.load(paths['wiki_views2M.pt'])
119
+ x = fs.torch_load(paths['wiki_features2M.pt'])
120
+ edge_index = fs.torch_load(paths['wiki_edges2M.pt']).t().contiguous()
121
+ y = fs.torch_load(paths['wiki_views2M.pt'])
121
122
 
122
123
  return Data(x=x, edge_index=edge_index, y=y)
123
124
 
@@ -188,9 +188,8 @@ class LRGBDataset(InMemoryDataset):
188
188
  graphs = pickle.load(f)
189
189
  elif self.name.split('-')[0] == 'peptides':
190
190
  # Peptides-func and Peptides-struct
191
- with open(osp.join(self.raw_dir, f'{split}.pt'),
192
- 'rb') as f:
193
- graphs = torch.load(f)
191
+ graphs = fs.torch_load(
192
+ osp.join(self.raw_dir, f'{split}.pt'))
194
193
 
195
194
  data_list = []
196
195
  for graph in tqdm(graphs, desc=f'Processing {split} dataset'):
@@ -260,8 +259,7 @@ class LRGBDataset(InMemoryDataset):
260
259
 
261
260
  def process_pcqm_contact(self) -> None:
262
261
  for split in ['train', 'val', 'test']:
263
- with open(osp.join(self.raw_dir, f'{split}.pt'), 'rb') as f:
264
- graphs = torch.load(f)
262
+ graphs = fs.torch_load(osp.join(self.raw_dir, f'{split}.pt'))
265
263
 
266
264
  data_list = []
267
265
  for graph in tqdm(graphs, desc=f'Processing {split} dataset'):
@@ -11,6 +11,7 @@ from torch_geometric.data import (
11
11
  extract_tar,
12
12
  extract_zip,
13
13
  )
14
+ from torch_geometric.io import fs
14
15
 
15
16
 
16
17
  class MalNetTiny(InMemoryDataset):
@@ -65,7 +66,7 @@ class MalNetTiny(InMemoryDataset):
65
66
  self.load(self.processed_paths[0])
66
67
 
67
68
  if split is not None:
68
- split_slices = torch.load(self.processed_paths[1])
69
+ split_slices = fs.torch_load(self.processed_paths[1])
69
70
  if split == 'train':
70
71
  self._indices = range(split_slices[0], split_slices[1])
71
72
  elif split == 'val':
@@ -57,7 +57,7 @@ class MD17(InMemoryDataset):
57
57
  +--------------------+--------------------+-------------------------------+-----------+
58
58
  | Uracil | DFT | :obj:`uracil` | 133,770 |
59
59
  +--------------------+--------------------+-------------------------------+-----------+
60
- | Naphthalene | DFT | :obj:`napthalene` | 326,250 |
60
+ | Naphthalene | DFT | :obj:`naphthalene` | 326,250 |
61
61
  +--------------------+--------------------+-------------------------------+-----------+
62
62
  | Aspirin | DFT | :obj:`aspirin` | 211,762 |
63
63
  +--------------------+--------------------+-------------------------------+-----------+
@@ -77,7 +77,7 @@ class MD17(InMemoryDataset):
77
77
  +--------------------+--------------------+-------------------------------+-----------+
78
78
  | Uracil (R) | DFT (PBE/def2-SVP) | :obj:`revised uracil` | 100,000 |
79
79
  +--------------------+--------------------+-------------------------------+-----------+
80
- | Naphthalene (R) | DFT (PBE/def2-SVP) | :obj:`revised napthalene` | 100,000 |
80
+ | Naphthalene (R) | DFT (PBE/def2-SVP) | :obj:`revised naphthalene` | 100,000 |
81
81
  +--------------------+--------------------+-------------------------------+-----------+
82
82
  | Aspirin (R) | DFT (PBE/def2-SVP) | :obj:`revised aspirin` | 100,000 |
83
83
  +--------------------+--------------------+-------------------------------+-----------+
@@ -309,7 +309,7 @@ class MD17(InMemoryDataset):
309
309
  file_names = {
310
310
  'benzene': 'md17_benzene2017.npz',
311
311
  'uracil': 'md17_uracil.npz',
312
- 'naphtalene': 'md17_naphthalene.npz',
312
+ 'naphthalene': 'md17_naphthalene.npz',
313
313
  'aspirin': 'md17_aspirin.npz',
314
314
  'salicylic acid': 'md17_salicylic.npz',
315
315
  'malonaldehyde': 'md17_malonaldehyde.npz',
@@ -0,0 +1,145 @@
1
+ import os
2
+ import os.path as osp
3
+ from typing import Callable, List, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from torch_geometric.data import Data, InMemoryDataset
9
+
10
+
11
class MedShapeNet(InMemoryDataset):
    r"""The MedShapeNet datasets from the `"MedShapeNet -- A Large-Scale
    Dataset of 3D Medical Shapes for Computer Vision"
    <https://arxiv.org/abs/2308.16139>`_ paper,
    containing 8 different types of structures (classes).

    .. note::

        Data objects hold mesh faces instead of edge indices.
        To convert the mesh to a graph, use the
        :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.
        To convert the mesh to a point cloud, use the
        :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to
        sample a fixed number of points on the mesh faces according to their
        face area.

    Args:
        root (str): Root directory where the dataset should be saved.
        size (int, optional): Number of individual 3D structures to download
            per type (class). (default: :obj:`100`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    # Remote MedShapeNet-core datasets that are not among the 8 classes of
    # `raw_file_names` and are therefore skipped during processing.
    _EXCLUDED_DATASETS = (
        'medshapenetcore/ASOCA',
        'medshapenetcore/AVT',
        'medshapenetcore/AutoImplantCraniotomy',
        'medshapenetcore/FaceVR',
    )

    def __init__(
        self,
        root: str,
        size: int = 100,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        # `size` must be set before `super().__init__`, which may trigger
        # `process()` (where `self.size` is read).
        self.size = size
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)

        path = self.processed_paths[0]
        self.load(path)

    @property
    def raw_file_names(self) -> List[str]:
        # The order of this list defines the integer class labels.
        return [
            '3DTeethSeg', 'CoronaryArteries', 'FLARE', 'KITS', 'PULMONARY',
            'SurgicalInstruments', 'ThoracicAorta_Saitta', 'ToothFairy'
        ]

    @property
    def processed_file_names(self) -> List[str]:
        return ['dataset.pt']

    @property
    def raw_paths(self) -> List[str]:
        r"""The absolute filepaths that must be present in order to skip
        downloading.
        """
        return [osp.join(self.raw_dir, f) for f in self.raw_file_names]

    def process(self) -> None:
        r"""Downloads up to :obj:`size` meshes per class via the
        :obj:`MedShapeNet` client, converts them to :class:`Data` objects
        holding :obj:`pos`, :obj:`face` and :obj:`y`, and saves them to
        :obj:`processed_paths[0]`.
        """
        from MedShapeNet import MedShapeNet as msn

        msn_instance = msn(timeout=120)

        # Class label == position in `raw_file_names`.
        class_mapping = {
            name: label
            for label, name in enumerate(self.raw_file_names)
        }

        subset: List[str] = []
        for dataset in msn_instance.datasets(False):
            if dataset in self._EXCLUDED_DATASETS:
                continue

            # Dataset names look like '<bucket>/<class>'; fall back to the
            # whole name when there is no '/'.
            parts = dataset.split('/')
            out_dir = osp.join(self.root, parts[1 if len(parts) > 1 else 0])
            os.makedirs(out_dir, exist_ok=True)

            stl_files = msn_instance.dataset_files(dataset, '.stl')
            stl_files = stl_files[:self.size]
            subset.extend(stl_files)

            for stl_file in stl_files:
                msn_instance.download_stl_as_numpy(bucket_name=dataset,
                                                   stl_file=stl_file,
                                                   output_dir=out_dir,
                                                   print_output=False)

        data_list = []
        for item in subset:
            # NOTE(review): assumes `item` looks like '<class>/<name>.stl'
            # and that the client saved it as '<root>/<class>/<name>.npz'.
            target = class_mapping[item.split('/')[0]]
            # `splitext` removes only the trailing extension; splitting on
            # the substring 'stl' would truncate names containing 'stl'.
            file = osp.join(self.root, osp.splitext(item)[0] + '.npz')

            # `np.load` on an `.npz` keeps the file open until closed.
            with np.load(file) as npz:
                pos = torch.tensor(npz['vertices'], dtype=torch.float)
                face = torch.tensor(npz['faces'], dtype=torch.long)

            data_list.append(
                Data(
                    pos=pos,
                    face=face.t().contiguous(),
                    y=torch.tensor([target], dtype=torch.long),
                ))

        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        self.save(data_list, self.processed_paths[0])