pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic.

Files changed (268)
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
torch_geometric/datasets/teeth3ds.py
@@ -0,0 +1,269 @@
+import json
+import os
+import os.path as osp
+from glob import glob
+from typing import Callable, Dict, List, Optional
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from torch_geometric.data import (
+    Data,
+    InMemoryDataset,
+    download_url,
+    extract_zip,
+)
+
+
+class Teeth3DS(InMemoryDataset):
+    r"""The Teeth3DS+ dataset from the `"An Extended Benchmark for Intra-oral
+    3D Scans Analysis" <https://crns-smartvision.github.io/teeth3ds/>`_ paper.
+
+    This dataset is the first comprehensive public benchmark designed to
+    advance the field of intra-oral 3D scan analysis, developed as part of the
+    3DTeethSeg 2022 and 3DTeethLand 2024 MICCAI challenges, aiming to drive
+    research in teeth identification, segmentation, labeling, 3D modeling,
+    and dental landmark identification.
+    The dataset includes at least 1,800 intra-oral scans (containing 23,999
+    annotated teeth) collected from 900 patients, covering both upper and lower
+    jaws separately.
+
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        split (str): The split name (one of :obj:`"Teeth3DS"`,
+            :obj:`"3DTeethSeg22_challenge"` or :obj:`"3DTeethLand_challenge"`).
+        train (bool, optional): If :obj:`True`, loads the training dataset,
+            otherwise the test dataset. (default: :obj:`True`)
+        num_samples (int, optional): Number of points to sample from each mesh.
+            (default: :obj:`30000`)
+        transform (callable, optional): A function/transform that takes in an
+            :obj:`torch_geometric.data.Data` object and returns a transformed
+            version. The data object will be transformed before every access.
+            (default: :obj:`None`)
+        pre_transform (callable, optional): A function/transform that takes in
+            an :obj:`torch_geometric.data.Data` object and returns a
+            transformed version. The data object will be transformed before
+            being saved to disk. (default: :obj:`None`)
+        force_reload (bool, optional): Whether to re-process the dataset.
+            (default: :obj:`False`)
+    """
+    urls = {
+        'data_part_1.zip':
+        'https://osf.io/download/qhprs/',
+        'data_part_2.zip':
+        'https://osf.io/download/4pwnr/',
+        'data_part_3.zip':
+        'https://osf.io/download/frwdp/',
+        'data_part_4.zip':
+        'https://osf.io/download/2arn4/',
+        'data_part_5.zip':
+        'https://osf.io/download/xrz5f/',
+        'data_part_6.zip':
+        'https://osf.io/download/23hgq/',
+        'data_part_7.zip':
+        'https://osf.io/download/u83ad/',
+        'train_test_split':
+        'https://files.de-1.osf.io/v1/'
+        'resources/xctdy/providers/osfstorage/?zip='
+    }
+
+    sample_url = {
+        'teeth3ds_sample': 'https://osf.io/download/vr38s/',
+    }
+
+    landmarks_urls = {
+        '3DTeethLand_landmarks_train.zip': 'https://osf.io/download/k5hbj/',
+        '3DTeethLand_landmarks_test.zip': 'https://osf.io/download/sqw5e/',
+    }
+
+    def __init__(
+        self,
+        split:
+        str = 'Teeth3DS',  # [3DTeethSeg22_challenge, 3DTeethLand_challenge]
+        train: bool = True,
+        num_samples: int = 30000,
+        transform: Optional[Callable] = None,
+        pre_transform: Optional[Callable] = None,
+        force_reload: bool = False,
+    ) -> None:
+
+        self.mode = 'training' if train else 'testing'
+        self.split = split
+        self.num_samples = num_samples
+
+        super().__init__(root, transform, pre_transform,
+                         force_reload=force_reload)
+
+    @property
+    def processed_dir(self) -> str:
+        return os.path.join(self.root, f'processed_{self.split}_{self.mode}')
+
+    @property
+    def raw_file_names(self) -> List[str]:
+        return ['license.txt']
+
+    @property
+    def processed_file_names(self) -> List[str]:
+        # Directory containing the train/test split files:
+        split_subdir = 'teeth3ds_sample' if self.split == 'sample' else ''
+        split_dir = osp.join(
+            self.raw_dir,
+            split_subdir,
+            f'{self.split}_train_test_split',
+        )
+
+        split_files = glob(osp.join(split_dir, f'{self.mode}*.txt'))
+
+        # Collect all file names from the split files:
+        combined_list = []
+        for file_path in split_files:
+            with open(file_path) as file:
+                combined_list.extend(file.read().splitlines())
+
+        # Generate the list of processed file paths:
+        return [f'{file_name}.pt' for file_name in combined_list]
+
+    def download(self) -> None:
+        if self.split == 'sample':
+            for key, url in self.sample_url.items():
+                path = download_url(url, self.root, filename=key)
+                extract_zip(path, self.raw_dir)
+                os.unlink(path)
+        else:
+            for key, url in self.urls.items():
+                path = download_url(url, self.root, filename=key)
+                extract_zip(path, self.raw_dir)
+                os.unlink(path)
+            for key, url in self.landmarks_urls.items():
+                path = download_url(url, self.root, filename=key)
+                extract_zip(path, self.raw_dir)  # Extract each downloaded part
+                os.unlink(path)
+
+    def process_file(self, file_path: str) -> Optional[Data]:
+        """Processes the input file path to load mesh data and annotations,
+        and prepares the input features for a graph-based deep learning model.
+        """
+        import trimesh
+        from fpsample import bucket_fps_kdline_sampling
+
+        mesh = trimesh.load_mesh(file_path)
+
+        if isinstance(mesh, list):
+            # Handle the case where a list of geometry objects is returned:
+            mesh = mesh[0]
+
+        vertices = mesh.vertices
+        vertex_normals = mesh.vertex_normals
+
+        # Perform sampling on mesh vertices:
+        if len(vertices) < self.num_samples:
+            sampled_indices = np.random.choice(
+                len(vertices),
+                self.num_samples,
+                replace=True,
+            )
+        else:
+            sampled_indices = bucket_fps_kdline_sampling(
+                vertices,
+                self.num_samples,
+                h=5,
+                start_idx=0,
+            )
+
+        if len(sampled_indices) != self.num_samples:
+            raise RuntimeError(f"Sampled points mismatch, expected "
+                               f"{self.num_samples} points, but got "
+                               f"{len(sampled_indices)} for '{file_path}'")
+
+        # Extract features and annotations for the sampled points:
+        pos = torch.tensor(vertices[sampled_indices], dtype=torch.float)
+        x = torch.tensor(vertex_normals[sampled_indices], dtype=torch.float)
+
+        # Load segmentation annotations:
+        seg_annotation_path = file_path.replace('.obj', '.json')
+        if osp.exists(seg_annotation_path):
+            with open(seg_annotation_path) as f:
+                seg_annotations = json.load(f)
+            y = torch.tensor(
+                np.asarray(seg_annotations['labels'])[sampled_indices],
+                dtype=torch.float)
+            instances = torch.tensor(
+                np.asarray(seg_annotations['instances'])[sampled_indices],
+                dtype=torch.float)
+        else:
+            y = torch.empty(0, 3)
+            instances = torch.empty(0, 3)
+
+        # Load landmark annotations:
+        landmarks_annotation_path = file_path.replace('.obj', '__kpt.json')
+
+        # Parse keypoint annotations into structured tensors:
+        keypoints_dict: Dict[str, List] = {
+            key: []
+            for key in [
+                'Mesial', 'Distal', 'Cusp', 'InnerPoint', 'OuterPoint',
+                'FacialPoint'
+            ]
+        }
+        keypoint_tensors: Dict[str, torch.Tensor] = {
+            key: torch.empty(0, 3)
+            for key in [
+                'Mesial', 'Distal', 'Cusp', 'InnerPoint', 'OuterPoint',
+                'FacialPoint'
+            ]
+        }
+        if osp.exists(landmarks_annotation_path):
+            with open(landmarks_annotation_path) as f:
+                landmarks_annotations = json.load(f)
+
+            for keypoint in landmarks_annotations['objects']:
+                keypoints_dict[keypoint['class']].extend(keypoint['coord'])
+
+            keypoint_tensors = {
+                k: torch.tensor(np.asarray(v),
+                                dtype=torch.float).reshape(-1, 3)
+                for k, v in keypoints_dict.items()
+            }
+
+        data = Data(
+            pos=pos,
+            x=x,
+            y=y,
+            instances=instances,
+            jaw=file_path.split('.obj')[0].split('_')[1],
+            mesial=keypoint_tensors['Mesial'],
+            distal=keypoint_tensors['Distal'],
+            cusp=keypoint_tensors['Cusp'],
+            inner_point=keypoint_tensors['InnerPoint'],
+            outer_point=keypoint_tensors['OuterPoint'],
+            facial_point=keypoint_tensors['FacialPoint'],
+        )
+
+        if self.pre_transform is not None:
+            data = self.pre_transform(data)
+
+        return data
+
+    def process(self) -> None:
+        for file in tqdm(self.processed_file_names):
+            name = file.split('.')[0]
+            path = osp.join(self.raw_dir, '**', '*', name + '.obj')
+            paths = glob(path)
+            if len(paths) == 1:
+                data = self.process_file(paths[0])
+                torch.save(data, osp.join(self.processed_dir, file))
+
+    def len(self) -> int:
+        return len(self.processed_file_names)
+
+    def get(self, idx: int) -> Data:
+        return torch.load(
+            osp.join(self.processed_dir, self.processed_file_names[idx]),
+            weights_only=False,
+        )
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}({len(self)}, '
+                f'mode={self.mode}, split={self.split})')
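
As a quick orientation to the new Teeth3DS dataset above, here is a minimal usage sketch (not part of the diff). It assumes that Teeth3DS is exported from torch_geometric.datasets, as the updated datasets/__init__.py suggests, and that the optional trimesh and fpsample dependencies are installed; the 'sample' split used below is the small demo split that download() special-cases.

    from torch_geometric.datasets import Teeth3DS

    # First access downloads the scans, samples `num_samples` points per mesh
    # (farthest-point sampling via fpsample), and caches one '.pt' file per
    # mesh in `processed_<split>_<mode>/`:
    dataset = Teeth3DS(root='data/Teeth3DS', split='sample', train=True)

    data = dataset[0]
    print(data.pos.shape)     # [30000, 3] sampled vertex positions
    print(data.x.shape)       # [30000, 3] vertex normals as input features
    print(data.jaw)           # 'upper' or 'lower', parsed from the file name
    print(data.mesial.shape)  # [num_keypoints, 3] landmark coordinates
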
torch_geometric/datasets/web_qsp_dataset.py
@@ -0,0 +1,342 @@
+# Code adapted from the G-Retriever paper: https://arxiv.org/abs/2402.07630
+import gc
+import os
+from itertools import chain
+from typing import Any, Dict, Iterator, List, Optional
+
+import torch
+from tqdm import tqdm
+
+from torch_geometric.data import InMemoryDataset
+from torch_geometric.llm.large_graph_indexer import (
+    EDGE_RELATION,
+    LargeGraphIndexer,
+    TripletLike,
+    get_features_for_triplets_groups,
+)
+from torch_geometric.llm.models import SentenceTransformer
+from torch_geometric.llm.utils.backend_utils import (
+    preprocess_triplet,
+    retrieval_via_pcst,
+)
+
+
+class KGQABaseDataset(InMemoryDataset):
+    r"""Base class for the two KGQA datasets used in the `"Reasoning on
+    Graphs: Faithful and Interpretable Large Language Model Reasoning"
+    <https://arxiv.org/pdf/2310.01061>`_ paper.
+
+    Args:
+        dataset_name (str): The HuggingFace dataset name.
+        root (str): Root directory where the dataset should be saved.
+        split (str, optional): If :obj:`"train"`, loads the training dataset.
+            If :obj:`"val"`, loads the validation dataset.
+            If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
+        force_reload (bool, optional): Whether to re-process the dataset.
+            (default: :obj:`False`)
+        verbose (bool, optional): Whether to print output. (default: :obj:`False`)
+        use_pcst (bool, optional): Whether to preprocess the dataset's graphs
+            with PCST or return the full graphs. (default: :obj:`True`)
+        load_dataset_kwargs (dict, optional):
+            Keyword arguments for the `datasets.load_dataset` function.
+            (default: :obj:`{}`)
+        retrieval_kwargs (dict, optional):
+            Keyword arguments for the
+            `get_features_for_triplets_groups` function.
+            (default: :obj:`{}`)
+    """
+    def __init__(
+        self,
+        dataset_name: str,
+        root: str,
+        split: str = "train",
+        force_reload: bool = False,
+        verbose: bool = False,
+        use_pcst: bool = True,
+        load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+        retrieval_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        self.split = split
+        self.dataset_name = dataset_name
+        self.use_pcst = use_pcst
+        self.load_dataset_kwargs = load_dataset_kwargs or {}
+        """
+        NOTE: If running into memory issues,
+        try reducing this batch size for the LargeGraphIndexer
+        used to build our KG.
+        Example: self.retrieval_kwargs = {"batch_size": 64}
+        """
+        self.retrieval_kwargs = retrieval_kwargs or {}
+
+        # Caching custom subsets of the dataset results in unsupported behavior
+        if 'split' in self.load_dataset_kwargs:
+            print("WARNING: Caching custom subsets of the dataset \
+                results in unsupported behavior.\
+                Please specify a separate root directory for each split,\
+                or set force_reload=True on subsequent instantiations\
+                of the dataset.")
+
+        self.required_splits = ['train', 'validation', 'test']
+
+        self.verbose = verbose
+        self.force_reload = force_reload
+        super().__init__(root, force_reload=force_reload)
+        """
+        NOTE: Current behavior is to process the entire dataset,
+        and only return the split specified by the user.
+        """
+        if f'{split}_data.pt' not in set(self.processed_file_names):
+            raise ValueError(f"Invalid 'split' argument (got {split})")
+        if split == 'val':
+            split = 'validation'
+
+        self.load(self.processed_paths[self.required_splits.index(split)])
+
+    @property
+    def raw_file_names(self) -> List[str]:
+        return ["raw.pt"]
+
+    @property
+    def processed_file_names(self) -> List[str]:
+        return ["train_data.pt", "val_data.pt", "test_data.pt"]
+
+    def download(self) -> None:
+        import datasets
+
+        # Load the HuggingFace dataset by name if no path is specified:
+        self.load_dataset_kwargs['path'] = self.load_dataset_kwargs.get(
+            'path', self.dataset_name)
+        raw_dataset = datasets.load_dataset(**self.load_dataset_kwargs)
+
+        # Assert that the dataset contains the required splits:
+        assert all(split in raw_dataset for split in self.required_splits), \
+            f"Dataset '{self.dataset_name}' is missing required splits: \
+            {self.required_splits}"
+
+        raw_dataset.save_to_disk(self.raw_paths[0])
+
+    def _get_trips(self) -> Iterator[TripletLike]:
+        # Iterate over each element's graph in each split of the dataset,
+        # using `chain` to lazily iterate without storing all trips in memory:
+        split_iterators = []
+
+        for split in self.required_splits:
+            # Create an iterator for each element's graph in the current split:
+            split_graphs = (element['graph']
+                            for element in self.raw_dataset[split])
+            split_iterators.append(chain.from_iterable(split_graphs))
+
+        # Chain all split iterators together:
+        return chain.from_iterable(split_iterators)
+
+    def _build_graph(self) -> None:
+        print("Encoding graph...")
+        trips = self._get_trips()
+        self.indexer: LargeGraphIndexer = LargeGraphIndexer.from_triplets(
+            trips, pre_transform=preprocess_triplet)
+
+        # Nodes:
+        print("\tEncoding nodes...")
+        nodes = self.indexer.get_unique_node_features()
+        x = self.model.encode(nodes, batch_size=256, output_device='cpu')
+        self.indexer.add_node_feature(new_feature_name="x", new_feature_vals=x)
+
+        # Edges:
+        print("\tEncoding edges...")
+        edges = self.indexer.get_unique_edge_features(
+            feature_name=EDGE_RELATION)
+        edge_attr = self.model.encode(edges, batch_size=256,
+                                      output_device='cpu')
+        self.indexer.add_edge_feature(
+            new_feature_name="edge_attr",
+            new_feature_vals=edge_attr,
+            map_from_feature=EDGE_RELATION,
+        )
+
+        print("\tSaving graph...")
+        self.indexer.save(self.indexer_path)
+
+    def _retrieve_subgraphs(self) -> None:
+        raw_splits = [
+            self.raw_dataset[split] for split in self.required_splits
+        ]
+        zipped = zip(
+            self.required_splits,
+            raw_splits,  # noqa
+            self.processed_paths,
+        )
+        for split_name, dataset, path in zipped:
+            print(f"Processing {split_name} split...")
+
+            print("\tEncoding questions...")
+            split_questions = [str(element['question']) for element in dataset]
+            split_q_embs = self.model.encode(split_questions, batch_size=256,
+                                             output_device='cpu')
+
+            print("\tRetrieving subgraphs...")
+            results_graphs = []
+            retrieval_kwargs = {
+                **self.retrieval_kwargs,
+                **{
+                    'pre_transform': preprocess_triplet,
+                    'verbose': self.verbose,
+                }
+            }
+            graph_gen = get_features_for_triplets_groups(
+                self.indexer, (element['graph'] for element in dataset),
+                **retrieval_kwargs)
+
+            for index in tqdm(range(len(dataset)), disable=not self.verbose):
+                data_i = dataset[index]
+                graph = next(graph_gen)
+                textual_nodes = self.textual_nodes.iloc[
+                    graph["node_idx"]].reset_index()
+                textual_edges = self.textual_edges.iloc[
+                    graph["edge_idx"]].reset_index()
+                if self.use_pcst and len(textual_nodes) > 0 and len(
+                        textual_edges) > 0:
+                    subgraph, desc = retrieval_via_pcst(
+                        graph,
+                        split_q_embs[index],
+                        textual_nodes,
+                        textual_edges,
+                    )
+                else:
+                    desc = textual_nodes.to_csv(
+                        index=False) + "\n" + textual_edges.to_csv(
+                            index=False,
+                            columns=["src", "edge_attr", "dst"],
+                        )
+                    subgraph = graph
+                question = f"Question: {data_i['question']}\nAnswer: "
+                label = ("|").join(data_i["answer"]).lower()
+
+                subgraph["question"] = question
+                subgraph["label"] = label
+                subgraph["desc"] = desc
+                results_graphs.append(subgraph.to("cpu"))
+            print("\tSaving subgraphs...")
+            self.save(results_graphs, path)
+
+    def process(self) -> None:
+        import datasets
+        from pandas import DataFrame
+        self.raw_dataset = datasets.load_from_disk(self.raw_paths[0])
+
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        model_name = 'sentence-transformers/all-roberta-large-v1'
+        self.model: SentenceTransformer = SentenceTransformer(model_name).to(
+            device)
+        self.model.eval()
+        self.indexer_path = os.path.join(self.processed_dir,
+                                         "large_graph_indexer")
+        if self.force_reload or not os.path.exists(self.indexer_path):
+            self._build_graph()
+        else:
+            print("Loading graph...")
+            self.indexer = LargeGraphIndexer.from_disk(self.indexer_path)
+        self.textual_nodes = DataFrame.from_dict(
+            {"node_attr": self.indexer.get_node_features()})
+        self.textual_nodes["node_id"] = self.textual_nodes.index
+        self.textual_nodes = self.textual_nodes[["node_id", "node_attr"]]
+        self.textual_edges = DataFrame(self.indexer.get_edge_features(),
+                                       columns=["src", "edge_attr", "dst"])
+        self.textual_edges["src"] = [
+            self.indexer._nodes[h] for h in self.textual_edges["src"]
+        ]
+        self.textual_edges["dst"] = [
+            self.indexer._nodes[h] for h in self.textual_edges["dst"]
+        ]
+        self._retrieve_subgraphs()
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+
+class WebQSPDataset(KGQABaseDataset):
+    r"""The WebQuestionsSP dataset of the `"The Value of Semantic Parse
+    Labeling for Knowledge Base Question Answering"
+    <https://aclanthology.org/P16-2033/>`_ paper.
+
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        split (str, optional): If :obj:`"train"`, loads the training dataset.
+            If :obj:`"val"`, loads the validation dataset.
+            If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
+        force_reload (bool, optional): Whether to re-process the dataset.
+            (default: :obj:`False`)
+        verbose (bool, optional): Whether to print output. (default: :obj:`False`)
+        use_pcst (bool, optional): Whether to preprocess the dataset's graphs
+            with PCST or return the full graphs. (default: :obj:`True`)
+        load_dataset_kwargs (dict, optional):
+            Keyword arguments for the `datasets.load_dataset` function.
+            (default: :obj:`{}`)
+        retrieval_kwargs (dict, optional):
+            Keyword arguments for the
+            `get_features_for_triplets_groups` function.
+            (default: :obj:`{}`)
+    """
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        force_reload: bool = False,
+        verbose: bool = False,
+        use_pcst: bool = True,
+        load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+        retrieval_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        load_dataset_kwargs = load_dataset_kwargs or {}
+        retrieval_kwargs = retrieval_kwargs or {}
+        # Modify these parameters if running into memory/compute issues:
+        default_retrieval_kwargs = {
+            'max_batch_size': 250,  # Lower batch size to reduce memory usage
+            'num_workers':
+            None,  # Use all available workers, or set to a number of threads
+        }
+        retrieval_kwargs = {**default_retrieval_kwargs, **retrieval_kwargs}
+        dataset_name = 'rmanluo/RoG-webqsp'
+        super().__init__(dataset_name, root, split, force_reload, verbose,
+                         use_pcst, load_dataset_kwargs=load_dataset_kwargs,
+                         retrieval_kwargs=retrieval_kwargs)
+
+
+class CWQDataset(KGQABaseDataset):
+    r"""The ComplexWebQuestions (CWQ) dataset of the `"The Web as a
+    Knowledge-base for Answering Complex Questions"
+    <https://arxiv.org/pdf/1803.06643>`_ paper.
+
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        split (str, optional): If :obj:`"train"`, loads the training dataset.
+            If :obj:`"val"`, loads the validation dataset.
+            If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
+        force_reload (bool, optional): Whether to re-process the dataset.
+            (default: :obj:`False`)
+        verbose (bool, optional): Whether to print output. (default: :obj:`False`)
+        use_pcst (bool, optional): Whether to preprocess the dataset's graphs
+            with PCST or return the full graphs. (default: :obj:`True`)
+        load_dataset_kwargs (dict, optional):
+            Keyword arguments for the `datasets.load_dataset` function.
+            (default: :obj:`{}`)
+        retrieval_kwargs (dict, optional):
+            Keyword arguments for the
+            `get_features_for_triplets_groups` function.
+            (default: :obj:`{}`)
+    """
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        force_reload: bool = False,
+        verbose: bool = False,
+        use_pcst: bool = True,
+        load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+        retrieval_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        load_dataset_kwargs = load_dataset_kwargs or {}
+        retrieval_kwargs = retrieval_kwargs or {}
+        dataset_name = 'rmanluo/RoG-cwq'
+        super().__init__(dataset_name, root, split, force_reload, verbose,
+                         use_pcst, load_dataset_kwargs=load_dataset_kwargs,
+                         retrieval_kwargs=retrieval_kwargs)
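
WebQSPDataset and CWQDataset both inherit the KGQABaseDataset pipeline above: download the HuggingFace splits, encode the knowledge graph once with a SentenceTransformer through LargeGraphIndexer, and cache one (optionally PCST-pruned) subgraph per question. A minimal usage sketch (not part of the diff), assuming the optional datasets and pandas dependencies are installed, plus pcst_fast when use_pcst=True:

    from torch_geometric.datasets import WebQSPDataset

    # The first instantiation processes all three splits and caches them;
    # later instantiations only load the requested split:
    train_dataset = WebQSPDataset(root='data/WebQSP', split='train')

    data = train_dataset[0]
    print(data.question)  # "Question: ...\nAnswer: "
    print(data.label)     # Ground-truth answers joined by '|', lower-cased
    print(data.desc)      # Textual subgraph description (e.g., for G-Retriever)
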
torch_geometric/datasets/wikics.py
@@ -45,7 +45,8 @@ class WikiCS(InMemoryDataset):
             warnings.warn(
                 f"The {self.__class__.__name__} dataset now returns an "
                 f"undirected graph by default. Please explicitly specify "
-                f"'is_undirected=False' to restore the old behavior.")
+                f"'is_undirected=False' to restore the old behavior.",
+                stacklevel=2)
             is_undirected = True
         self.is_undirected = is_undirected
         super().__init__(root, transform, pre_transform,
torch_geometric/datasets/wikidata.py
@@ -10,6 +10,7 @@ from torch_geometric.data import (
     download_url,
     extract_tar,
 )
+from torch_geometric.io import fs


 class Wikidata5M(InMemoryDataset):
@@ -99,7 +100,7 @@ class Wikidata5M(InMemoryDataset):
                 values = line.strip().split('\t')
                 entity_to_id[values[0]] = i

-        x = torch.load(self.raw_paths[1])
+        x = fs.torch_load(self.raw_paths[1])

         edge_indices = []
         edge_types = []
torch_geometric/deprecation.py
@@ -23,7 +23,7 @@ def deprecated(
             out = f"'{name}' is deprecated"
             if details is not None:
                 out += f", {details}"
-            warnings.warn(out)
+            warnings.warn(out, stacklevel=2)
             return func(*args, **kwargs)

         return wrapper
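
The wikics.py and deprecation.py hunks above apply the same one-line fix: passing stacklevel=2 to warnings.warn attributes the warning to the caller of the library function rather than to the warn() call inside the library. A small self-contained illustration of the effect (hypothetical function, not from the diff):

    import warnings

    def old_api():
        # With the default stacklevel=1 the reported location would be this
        # line; stacklevel=2 walks one frame up to the user's call site:
        warnings.warn("'old_api' is deprecated", stacklevel=2)

    old_api()  # The emitted UserWarning now points at this line.
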
torch_geometric/distributed/__init__.py
@@ -1,3 +1,5 @@
+from warnings import warn
+
 from .dist_context import DistContext
 from .local_feature_store import LocalFeatureStore
 from .local_graph_store import LocalGraphStore
@@ -7,6 +9,17 @@ from .dist_loader import DistLoader
 from .dist_neighbor_loader import DistNeighborLoader
 from .dist_link_neighbor_loader import DistLinkNeighborLoader

+warn(
+    "`torch_geometric.distributed` has been deprecated since 2.7.0 and will "
+    "no longer be maintained. For distributed training, refer to our "
+    "tutorials on distributed training at "
+    "https://pytorch-geometric.readthedocs.io/en/latest/tutorial/distributed.html "  # noqa: E501
+    "or cuGraph examples at "
+    "https://github.com/rapidsai/cugraph-gnn/tree/main/python/cugraph-pyg/cugraph_pyg/examples",  # noqa: E501
+    stacklevel=2,
+    category=DeprecationWarning,
+)
+
 __all__ = classes = [
     'DistContext',
     'LocalFeatureStore',
torch_geometric/distributed/dist_loader.py
@@ -138,9 +138,9 @@ class DistLoader:
             # close RPC & worker group at exit:
             atexit.register(shutdown_rpc, self.current_ctx_worker.worker_name)

-        except RuntimeError:
+        except RuntimeError as e:
             raise RuntimeError(f"`{self}.init_fn()` could not initialize the "
-                               f"worker loop of the neighbor sampler")
+                               f"worker loop of the neighbor sampler") from e

     def __repr__(self) -> str:
         return f'{self.__class__.__name__}(pid={self.pid})'
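
The dist_loader.py hunk adopts explicit exception chaining: re-raising with `from e` records the caught error as __cause__, so the traceback reports "The above exception was the direct cause of the following exception" instead of the misleading "During handling of the above exception, another exception occurred". A minimal sketch of the pattern (hypothetical names, not from the diff):

    def init_worker_loop():  # Hypothetical stand-in for the sampler init.
        raise RuntimeError("RPC agent not initialized")

    try:
        init_worker_loop()
    except RuntimeError as e:
        # 'from e' preserves the original RuntimeError in the traceback:
        raise RuntimeError("could not initialize the worker loop") from e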