pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (229)
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_to_dense_batch.py +2 -2
  215. torch_geometric/utils/_trim_to_layer.py +2 -2
  216. torch_geometric/utils/convert.py +17 -10
  217. torch_geometric/utils/cross_entropy.py +34 -13
  218. torch_geometric/utils/embedding.py +91 -2
  219. torch_geometric/utils/geodesic.py +4 -3
  220. torch_geometric/utils/influence.py +279 -0
  221. torch_geometric/utils/map.py +13 -9
  222. torch_geometric/utils/nested.py +1 -1
  223. torch_geometric/utils/smiles.py +3 -3
  224. torch_geometric/utils/sparse.py +7 -14
  225. torch_geometric/visualization/__init__.py +2 -1
  226. torch_geometric/visualization/graph.py +250 -5
  227. torch_geometric/warnings.py +11 -2
  228. torch_geometric/nn/nlp/__init__.py +0 -7
  229. torch_geometric/nn/nlp/sentence_transformer.py +0 -101

torch_geometric/datasets/web_qsp_dataset.py

@@ -1,241 +1,342 @@
  # Code adapted from the G-Retriever paper: https://arxiv.org/abs/2402.07630
- from typing import Any, Dict, List, Tuple, no_type_check
+ import gc
+ import os
+ from itertools import chain
+ from typing import Any, Dict, Iterator, List, Optional

- import numpy as np
  import torch
- from torch import Tensor
  from tqdm import tqdm

- from torch_geometric.data import Data, InMemoryDataset
- from torch_geometric.nn.nlp import SentenceTransformer
-
-
- @no_type_check
- def retrieval_via_pcst(
-     data: Data,
-     q_emb: Tensor,
-     textual_nodes: Any,
-     textual_edges: Any,
-     topk: int = 3,
-     topk_e: int = 3,
-     cost_e: float = 0.5,
- ) -> Tuple[Data, str]:
-     c = 0.01
-     if len(textual_nodes) == 0 or len(textual_edges) == 0:
-         desc = textual_nodes.to_csv(index=False) + "\n" + textual_edges.to_csv(
-             index=False,
-             columns=["src", "edge_attr", "dst"],
-         )
-         return data, desc
-
-     from pcst_fast import pcst_fast
-
-     root = -1
-     num_clusters = 1
-     pruning = 'gw'
-     verbosity_level = 0
-     if topk > 0:
-         n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.x)
-         topk = min(topk, data.num_nodes)
-         _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
-
-         n_prizes = torch.zeros_like(n_prizes)
-         n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
-     else:
-         n_prizes = torch.zeros(data.num_nodes)
-
-     if topk_e > 0:
-         e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.edge_attr)
-         topk_e = min(topk_e, e_prizes.unique().size(0))
-
-         topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e, largest=True)
-         e_prizes[e_prizes < topk_e_values[-1]] = 0.0
-         last_topk_e_value = topk_e
-         for k in range(topk_e):
-             indices = e_prizes == topk_e_values[k]
-             value = min((topk_e - k) / sum(indices), last_topk_e_value - c)
-             e_prizes[indices] = value
-             last_topk_e_value = value * (1 - c)
-         # reduce the cost of the edges such that at least one edge is selected
-         cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
-     else:
-         e_prizes = torch.zeros(data.num_edges)
-
-     costs = []
-     edges = []
-     virtual_n_prizes = []
-     virtual_edges = []
-     virtual_costs = []
-     mapping_n = {}
-     mapping_e = {}
-     for i, (src, dst) in enumerate(data.edge_index.t().numpy()):
-         prize_e = e_prizes[i]
-         if prize_e <= cost_e:
-             mapping_e[len(edges)] = i
-             edges.append((src, dst))
-             costs.append(cost_e - prize_e)
-         else:
-             virtual_node_id = data.num_nodes + len(virtual_n_prizes)
-             mapping_n[virtual_node_id] = i
-             virtual_edges.append((src, virtual_node_id))
-             virtual_edges.append((virtual_node_id, dst))
-             virtual_costs.append(0)
-             virtual_costs.append(0)
-             virtual_n_prizes.append(prize_e - cost_e)
-
-     prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])
-     num_edges = len(edges)
-     if len(virtual_costs) > 0:
-         costs = np.array(costs + virtual_costs)
-         edges = np.array(edges + virtual_edges)
-
-     vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters,
-                                 pruning, verbosity_level)
-
-     selected_nodes = vertices[vertices < data.num_nodes]
-     selected_edges = [mapping_e[e] for e in edges if e < num_edges]
-     virtual_vertices = vertices[vertices >= data.num_nodes]
-     if len(virtual_vertices) > 0:
-         virtual_vertices = vertices[vertices >= data.num_nodes]
-         virtual_edges = [mapping_n[i] for i in virtual_vertices]
-         selected_edges = np.array(selected_edges + virtual_edges)
-
-     edge_index = data.edge_index[:, selected_edges]
-     selected_nodes = np.unique(
-         np.concatenate(
-             [selected_nodes, edge_index[0].numpy(), edge_index[1].numpy()]))
-
-     n = textual_nodes.iloc[selected_nodes]
-     e = textual_edges.iloc[selected_edges]
-     desc = n.to_csv(index=False) + '\n' + e.to_csv(
-         index=False, columns=['src', 'edge_attr', 'dst'])
-
-     mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
-     src = [mapping[i] for i in edge_index[0].tolist()]
-     dst = [mapping[i] for i in edge_index[1].tolist()]
-
-     data = Data(
-         x=data.x[selected_nodes],
-         edge_index=torch.tensor([src, dst]),
-         edge_attr=data.edge_attr[selected_edges],
-     )
-
-     return data, desc
-
-
- class WebQSPDataset(InMemoryDataset):
-     r"""The WebQuestionsSP dataset of the `"The Value of Semantic Parse
-     Labeling for Knowledge Base Question Answering"
-     <https://aclanthology.org/P16-2033/>`_ paper.
+ from torch_geometric.data import InMemoryDataset
+ from torch_geometric.llm.large_graph_indexer import (
+     EDGE_RELATION,
+     LargeGraphIndexer,
+     TripletLike,
+     get_features_for_triplets_groups,
+ )
+ from torch_geometric.llm.models import SentenceTransformer
+ from torch_geometric.llm.utils.backend_utils import (
+     preprocess_triplet,
+     retrieval_via_pcst,
+ )
+
+
+ class KGQABaseDataset(InMemoryDataset):
+     r"""Base class for the 2 KGQA datasets used in `"Reasoning on Graphs:
+     Faithful and Interpretable Large Language Model Reasoning"
+     <https://arxiv.org/pdf/2310.01061>`_ paper.

      Args:
+         dataset_name (str): HuggingFace `dataset` name.
          root (str): Root directory where the dataset should be saved.
          split (str, optional): If :obj:`"train"`, loads the training dataset.
              If :obj:`"val"`, loads the validation dataset.
              If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
          force_reload (bool, optional): Whether to re-process the dataset.
              (default: :obj:`False`)
+         verbose (bool, optional): Whether to print output. Defaults to False.
+         use_pcst (bool, optional): Whether to preprocess the dataset's graph
+             with PCST or return the full graphs. (default: :obj:`True`)
+         load_dataset_kwargs (dict, optional):
+             Keyword arguments for the `datasets.load_dataset` function.
+             (default: :obj:`{}`)
+         retrieval_kwargs (dict, optional):
+             Keyword arguments for the
+             `get_features_for_triplets_groups` function.
+             (default: :obj:`{}`)
      """
      def __init__(
          self,
+         dataset_name: str,
          root: str,
          split: str = "train",
          force_reload: bool = False,
+         verbose: bool = False,
+         use_pcst: bool = True,
+         load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+         retrieval_kwargs: Optional[Dict[str, Any]] = None,
      ) -> None:
+         self.split = split
+         self.dataset_name = dataset_name
+         self.use_pcst = use_pcst
+         self.load_dataset_kwargs = load_dataset_kwargs or {}
+         """
+         NOTE: If running into memory issues,
+         try reducing this batch size for the LargeGraphIndexer
+         used to build our KG.
+         Example: self.retrieval_kwargs = {"batch_size": 64}
+         """
+         self.retrieval_kwargs = retrieval_kwargs or {}
+
+         # Caching custom subsets of the dataset results in unsupported behavior
+         if 'split' in self.load_dataset_kwargs:
+             print("WARNING: Caching custom subsets of the dataset \
+                 results in unsupported behavior.\
+                 Please specify a separate root directory for each split,\
+                 or set force_reload=True on subsequent instantiations\
+                 of the dataset.")
+
+         self.required_splits = ['train', 'validation', 'test']
+
+         self.verbose = verbose
+         self.force_reload = force_reload
          super().__init__(root, force_reload=force_reload)
-
-         if split not in {'train', 'val', 'test'}:
+         """
+         NOTE: Current behavior is to process the entire dataset,
+         and only return the split specified by the user.
+         """
+         if f'{split}_data.pt' not in set(self.processed_file_names):
              raise ValueError(f"Invalid 'split' argument (got {split})")
+         if split == 'val':
+             split = 'validation'

-         path = self.processed_paths[['train', 'val', 'test'].index(split)]
-         self.load(path)
+         self.load(self.processed_paths[self.required_splits.index(split)])
+
+     @property
+     def raw_file_names(self) -> List[str]:
+         return ["raw.pt"]

      @property
      def processed_file_names(self) -> List[str]:
-         return ['train_data.pt', 'val_data.pt', 'test_data.pt']
+         return ["train_data.pt", "val_data.pt", "test_data.pt"]

-     def process(self) -> None:
+     def download(self) -> None:
          import datasets
-         import pandas as pd

-         datasets = datasets.load_dataset('rmanluo/RoG-webqsp')
+         # HF Load Dataset by dataset name if no path is specified
+         self.load_dataset_kwargs['path'] = self.load_dataset_kwargs.get(
+             'path', self.dataset_name)
+         raw_dataset = datasets.load_dataset(**self.load_dataset_kwargs)
+
+         # Assert that the dataset contains the required splits
+         assert all(split in raw_dataset for split in self.required_splits), \
+             f"Dataset '{self.dataset_name}' is missing required splits: \
+             {self.required_splits}"
+
+         raw_dataset.save_to_disk(self.raw_paths[0])
+
+     def _get_trips(self) -> Iterator[TripletLike]:
+         # Iterate over each element's graph in each split of the dataset
+         # Using chain to lazily iterate without storing all trips in memory
+         split_iterators = []
+
+         for split in self.required_splits:
+             # Create an iterator for each element's graph in the current split
+             split_graphs = (element['graph']
+                             for element in self.raw_dataset[split])
+             split_iterators.append(chain.from_iterable(split_graphs))
+
+         # Chain all split iterators together
+         return chain.from_iterable(split_iterators)
+
+     def _build_graph(self) -> None:
+         print("Encoding graph...")
+         trips = self._get_trips()
+         self.indexer: LargeGraphIndexer = LargeGraphIndexer.from_triplets(
+             trips, pre_transform=preprocess_triplet)
+
+         # Nodes:
+         print("\tEncoding nodes...")
+         nodes = self.indexer.get_unique_node_features()
+         x = self.model.encode(nodes, batch_size=256, output_device='cpu')
+         self.indexer.add_node_feature(new_feature_name="x", new_feature_vals=x)
+
+         # Edges:
+         print("\tEncoding edges...")
+         edges = self.indexer.get_unique_edge_features(
+             feature_name=EDGE_RELATION)
+         edge_attr = self.model.encode(edges, batch_size=256,
+                                       output_device='cpu')
+         self.indexer.add_edge_feature(
+             new_feature_name="edge_attr",
+             new_feature_vals=edge_attr,
+             map_from_feature=EDGE_RELATION,
+         )
+
+         print("\tSaving graph...")
+         self.indexer.save(self.indexer_path)
+
+     def _retrieve_subgraphs(self) -> None:
+         raw_splits = [
+             self.raw_dataset[split] for split in self.required_splits
+         ]
+         zipped = zip(
+             self.required_splits,
+             raw_splits,  # noqa
+             self.processed_paths,
+         )
+         for split_name, dataset, path in zipped:
+             print(f"Processing {split_name} split...")
+
+             print("\tEncoding questions...")
+             split_questions = [str(element['question']) for element in dataset]
+             split_q_embs = self.model.encode(split_questions, batch_size=256,
+                                              output_device='cpu')
+
+             print("\tRetrieving subgraphs...")
+             results_graphs = []
+             retrieval_kwargs = {
+                 **self.retrieval_kwargs,
+                 **{
+                     'pre_transform': preprocess_triplet,
+                     'verbose': self.verbose,
+                 }
+             }
+             graph_gen = get_features_for_triplets_groups(
+                 self.indexer, (element['graph'] for element in dataset),
+                 **retrieval_kwargs)
+
+             for index in tqdm(range(len(dataset)), disable=not self.verbose):
+                 data_i = dataset[index]
+                 graph = next(graph_gen)
+                 textual_nodes = self.textual_nodes.iloc[
+                     graph["node_idx"]].reset_index()
+                 textual_edges = self.textual_edges.iloc[
+                     graph["edge_idx"]].reset_index()
+                 if self.use_pcst and len(textual_nodes) > 0 and len(
+                         textual_edges) > 0:
+                     subgraph, desc = retrieval_via_pcst(
+                         graph,
+                         split_q_embs[index],
+                         textual_nodes,
+                         textual_edges,
+                     )
+                 else:
+                     desc = textual_nodes.to_csv(
+                         index=False) + "\n" + textual_edges.to_csv(
+                             index=False,
+                             columns=["src", "edge_attr", "dst"],
+                         )
+                     subgraph = graph
+                 question = f"Question: {data_i['question']}\nAnswer: "
+                 label = ("|").join(data_i["answer"]).lower()
+
+                 subgraph["question"] = question
+                 subgraph["label"] = label
+                 subgraph["desc"] = desc
+                 results_graphs.append(subgraph.to("cpu"))
+             print("\tSaving subgraphs...")
+             self.save(results_graphs, path)
+
+     def process(self) -> None:
+         import datasets
+         from pandas import DataFrame
+         self.raw_dataset = datasets.load_from_disk(self.raw_paths[0])

          device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
          model_name = 'sentence-transformers/all-roberta-large-v1'
-         model = SentenceTransformer(model_name).to(device)
-         model.eval()
-
-         for dataset, path in zip(
-             [datasets['train'], datasets['validation'], datasets['test']],
-             self.processed_paths,
-         ):
-             questions = [example["question"] for example in dataset]
-             question_embs = model.encode(
-                 questions,
-                 batch_size=256,
-                 output_device='cpu',
-             )
-
-             data_list = []
-             for i, example in enumerate(tqdm(dataset)):
-                 raw_nodes: Dict[str, int] = {}
-                 raw_edges = []
-                 for tri in example["graph"]:
-                     h, r, t = tri
-                     h = h.lower()
-                     t = t.lower()
-                     if h not in raw_nodes:
-                         raw_nodes[h] = len(raw_nodes)
-                     if t not in raw_nodes:
-                         raw_nodes[t] = len(raw_nodes)
-                     raw_edges.append({
-                         "src": raw_nodes[h],
-                         "edge_attr": r,
-                         "dst": raw_nodes[t]
-                     })
-                 nodes = pd.DataFrame([{
-                     "node_id": v,
-                     "node_attr": k,
-                 } for k, v in raw_nodes.items()],
-                                      columns=["node_id", "node_attr"])
-                 edges = pd.DataFrame(raw_edges,
-                                      columns=["src", "edge_attr", "dst"])
-
-                 nodes.node_attr = nodes.node_attr.fillna("")
-                 x = model.encode(
-                     nodes.node_attr.tolist(),
-                     batch_size=256,
-                     output_device='cpu',
-                 )
-                 edge_attr = model.encode(
-                     edges.edge_attr.tolist(),
-                     batch_size=256,
-                     output_device='cpu',
-                 )
-                 edge_index = torch.tensor([
-                     edges.src.tolist(),
-                     edges.dst.tolist(),
-                 ], dtype=torch.long)
-
-                 question = f"Question: {example['question']}\nAnswer: "
-                 label = ('|').join(example['answer']).lower()
-                 data = Data(
-                     x=x,
-                     edge_index=edge_index,
-                     edge_attr=edge_attr,
-                 )
-                 data, desc = retrieval_via_pcst(
-                     data,
-                     question_embs[i],
-                     nodes,
-                     edges,
-                     topk=3,
-                     topk_e=5,
-                     cost_e=0.5,
-                 )
-                 data.question = question
-                 data.label = label
-                 data.desc = desc
-                 data_list.append(data)
-
-             self.save(data_list, path)
+         self.model: SentenceTransformer = SentenceTransformer(model_name).to(
+             device)
+         self.model.eval()
+         self.indexer_path = os.path.join(self.processed_dir,
+                                          "large_graph_indexer")
+         if self.force_reload or not os.path.exists(self.indexer_path):
+             self._build_graph()
+         else:
+             print("Loading graph...")
+             self.indexer = LargeGraphIndexer.from_disk(self.indexer_path)
+         self.textual_nodes = DataFrame.from_dict(
+             {"node_attr": self.indexer.get_node_features()})
+         self.textual_nodes["node_id"] = self.textual_nodes.index
+         self.textual_nodes = self.textual_nodes[["node_id", "node_attr"]]
+         self.textual_edges = DataFrame(self.indexer.get_edge_features(),
+                                        columns=["src", "edge_attr", "dst"])
+         self.textual_edges["src"] = [
+             self.indexer._nodes[h] for h in self.textual_edges["src"]
+         ]
+         self.textual_edges["dst"] = [
+             self.indexer._nodes[h] for h in self.textual_edges["dst"]
+         ]
+         self._retrieve_subgraphs()
+
+         gc.collect()
+         torch.cuda.empty_cache()
+
+
+ class WebQSPDataset(KGQABaseDataset):
+     r"""The WebQuestionsSP dataset of the `"The Value of Semantic Parse
+     Labeling for Knowledge Base Question Answering"
+     <https://aclanthology.org/P16-2033/>`_ paper.
+
+     Args:
+         root (str): Root directory where the dataset should be saved.
+         split (str, optional): If :obj:`"train"`, loads the training dataset.
+             If :obj:`"val"`, loads the validation dataset.
+             If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
+         force_reload (bool, optional): Whether to re-process the dataset.
+             (default: :obj:`False`)
+         verbose (bool, optional): Whether to print output. Defaults to False.
+         use_pcst (bool, optional): Whether to preprocess the dataset's graph
+             with PCST or return the full graphs. (default: :obj:`True`)
+         load_dataset_kwargs (dict, optional):
+             Keyword arguments for the `datasets.load_dataset` function.
+             (default: :obj:`{}`)
+         retrieval_kwargs (dict, optional):
+             Keyword arguments for the
+             `get_features_for_triplets_groups` function.
+             (default: :obj:`{}`)
+     """
+     def __init__(
+         self,
+         root: str,
+         split: str = "train",
+         force_reload: bool = False,
+         verbose: bool = False,
+         use_pcst: bool = True,
+         load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+         retrieval_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> None:
+         load_dataset_kwargs = load_dataset_kwargs or {}
+         retrieval_kwargs = retrieval_kwargs or {}
+         # Modify these parameters if running into memory/compute issues
+         default_retrieval_kwargs = {
+             'max_batch_size': 250,  # Lower batch size to reduce memory usage
+             'num_workers':
+             None,  # Use all available workers, or set to number of threads
+         }
+         retrieval_kwargs = {**default_retrieval_kwargs, **retrieval_kwargs}
+         dataset_name = 'rmanluo/RoG-webqsp'
+         super().__init__(dataset_name, root, split, force_reload, verbose,
+                          use_pcst, load_dataset_kwargs=load_dataset_kwargs,
+                          retrieval_kwargs=retrieval_kwargs)
+
+
+ class CWQDataset(KGQABaseDataset):
+     r"""The ComplexWebQuestions (CWQ) dataset of the `"The Web as a
+     Knowledge-base for Answering Complex Questions"
+     <https://arxiv.org/pdf/1803.06643>`_ paper.
+
+     Args:
+         root (str): Root directory where the dataset should be saved.
+         split (str, optional): If :obj:`"train"`, loads the training dataset.
+             If :obj:`"val"`, loads the validation dataset.
+             If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
+         force_reload (bool, optional): Whether to re-process the dataset.
+             (default: :obj:`False`)
+         verbose (bool, optional): Whether to print output. Defaults to False.
+         use_pcst (bool, optional): Whether to preprocess the dataset's graph
+             with PCST or return the full graphs. (default: :obj:`True`)
+         load_dataset_kwargs (dict, optional):
+             Keyword arguments for the `datasets.load_dataset` function.
+             (default: :obj:`{}`)
+         retrieval_kwargs (dict, optional):
+             Keyword arguments for the
+             `get_features_for_triplets_groups` function.
+             (default: :obj:`{}`)
+     """
+     def __init__(
+         self,
+         root: str,
+         split: str = "train",
+         force_reload: bool = False,
+         verbose: bool = False,
+         use_pcst: bool = True,
+         load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+         retrieval_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> None:
+         load_dataset_kwargs = load_dataset_kwargs or {}
+         retrieval_kwargs = retrieval_kwargs or {}
+         dataset_name = 'rmanluo/RoG-cwq'
+         super().__init__(dataset_name, root, split, force_reload, verbose,
+                          use_pcst, load_dataset_kwargs=load_dataset_kwargs,
+                          retrieval_kwargs=retrieval_kwargs)
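
The refactor above replaces the in-file PCST routine and ad-hoc processing loop with `LargeGraphIndexer`-backed download, encoding, and subgraph retrieval, shared by both datasets through `KGQABaseDataset`. A minimal usage sketch of the new constructor surface, assuming the public API shown in the diff; the root path and batch size are illustrative only, echoing the NOTE in `KGQABaseDataset.__init__`:

    from torch_geometric.datasets import WebQSPDataset

    # Lower the indexer batch size if graph encoding runs out of memory,
    # per the NOTE in the diff; 'data/webqsp' is an example root directory.
    train_dataset = WebQSPDataset(
        root='data/webqsp',
        split='train',
        use_pcst=True,  # prune each question graph with PCST
        retrieval_kwargs={'batch_size': 64},
    )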

torch_geometric/datasets/wikics.py

@@ -45,7 +45,8 @@ class WikiCS(InMemoryDataset):
              warnings.warn(
                  f"The {self.__class__.__name__} dataset now returns an "
                  f"undirected graph by default. Please explicitly specify "
-                 f"'is_undirected=False' to restore the old behavior.")
+                 f"'is_undirected=False' to restore the old behavior.",
+                 stacklevel=2)
              is_undirected = True
          self.is_undirected = is_undirected
          super().__init__(root, transform, pre_transform,

torch_geometric/deprecation.py

@@ -23,7 +23,7 @@ def deprecated(
              out = f"'{name}' is deprecated"
              if details is not None:
                  out += f", {details}"
-             warnings.warn(out)
+             warnings.warn(out, stacklevel=2)
              return func(*args, **kwargs)

          return wrapper
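
This hunk and the WikiCS one above both pass `stacklevel=2` so that Python attributes the warning to the caller of the deprecated function rather than to the `warnings.warn` call site inside the library. A minimal sketch of the effect (the function name here is illustrative):

    import warnings

    def old_api():
        # stacklevel=1 (the default) would report this line; stacklevel=2
        # reports the caller's line instead, matching the change above.
        warnings.warn("'old_api' is deprecated", stacklevel=2)

    old_api()  # the emitted warning is attributed to this line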

torch_geometric/distributed/__init__.py

@@ -1,3 +1,5 @@
+ from warnings import warn
+
  from .dist_context import DistContext
  from .local_feature_store import LocalFeatureStore
  from .local_graph_store import LocalGraphStore
@@ -7,6 +9,17 @@ from .dist_loader import DistLoader
  from .dist_neighbor_loader import DistNeighborLoader
  from .dist_link_neighbor_loader import DistLinkNeighborLoader

+ warn(
+     "`torch_geometric.distributed` has been deprecated since 2.7.0 and will "
+     "no longer be maintained. For distributed training, refer to our "
+     "tutorials on distributed training at "
+     "https://pytorch-geometric.readthedocs.io/en/latest/tutorial/distributed.html "  # noqa: E501
+     "or cuGraph examples at "
+     "https://github.com/rapidsai/cugraph-gnn/tree/main/python/cugraph-pyg/cugraph_pyg/examples",  # noqa: E501
+     stacklevel=2,
+     category=DeprecationWarning,
+ )
+
  __all__ = classes = [
      'DistContext',
      'LocalFeatureStore',
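
Since the module now emits a `DeprecationWarning` at import time, downstream code can surface or hard-fail on it with the standard-library `warnings` filters. A sketch, assuming the warning fires on first import as the code above implies:

    import warnings

    with warnings.catch_warnings():
        # Promote DeprecationWarning to an error to catch lingering imports:
        warnings.simplefilter('error', DeprecationWarning)
        import torch_geometric.distributed  # noqa: F401  # raises on first import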

torch_geometric/distributed/dist_loader.py

@@ -138,9 +138,9 @@ class DistLoader:
              # close RPC & worker group at exit:
              atexit.register(shutdown_rpc, self.current_ctx_worker.worker_name)

-         except RuntimeError:
+         except RuntimeError as e:
              raise RuntimeError(f"`{self}.init_fn()` could not initialize the "
-                                f"worker loop of the neighbor sampler")
+                                f"worker loop of the neighbor sampler") from e

      def __repr__(self) -> str:
          return f'{self.__class__.__name__}(pid={self.pid})'
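
`raise ... from e` chains the original error as `__cause__`, so the traceback shows both the low-level RPC failure and the loader-level message instead of discarding the former. The pattern in isolation:

    try:
        raise RuntimeError('worker loop failed')  # stand-in for the RPC error
    except RuntimeError as e:
        # The original exception is kept as __cause__ and rendered as
        # "The above exception was the direct cause of ..." in tracebacks:
        raise RuntimeError('`init_fn()` could not initialize the '
                           'worker loop of the neighbor sampler') from e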

torch_geometric/distributed/partition.py

@@ -304,7 +304,7 @@ class Partitioner:
              elif self.is_node_level_time:
                  node_time = data.time

-             # Sort by column to avoid keeping track of permuations in
+             # Sort by column to avoid keeping track of permutations in
              # `NeighborSampler` when converting to CSC format:
              global_row, global_col, perm = sort_csc(
                  global_row, global_col, node_time, edge_time)
@@ -361,7 +361,7 @@ class Partitioner:
              'edge_types': self.edge_types,
              'node_offset': list(node_offset.values()) if node_offset else None,
              'is_hetero': self.is_hetero,
-             'is_sorted': True,  # Based on colum/destination.
+             'is_sorted': True,  # Based on column/destination.
          }
          with open(osp.join(self.root, 'META.json'), 'w') as f:
              json.dump(meta, f)

torch_geometric/distributed/rpc.py

@@ -92,7 +92,7 @@ def shutdown_rpc(id: str = None, graceful: bool = True,
  class RPCRouter:
      r"""A router to get the worker based on the partition ID."""
      def __init__(self, partition_to_workers: List[List[str]]):
-         for pid, rpc_worker_list in enumerate(partition_to_workers):
+         for rpc_worker_list in partition_to_workers:
              if len(rpc_worker_list) == 0:
                  raise ValueError('No RPC worker is in worker list')
          self.partition_to_workers = partition_to_workers
@@ -120,7 +120,7 @@ def rpc_partition_to_workers(
      partition_to_workers = [[] for _ in range(num_partitions)]
      gathered_results = global_all_gather(
          (ctx.role, num_partitions, current_partition_idx))
-     for worker_name, (role, nparts, idx) in gathered_results.items():
+     for worker_name, (_, _, idx) in gathered_results.items():
          partition_to_workers[idx].append(worker_name)
      return partition_to_workers

@@ -144,7 +144,7 @@ _rpc_call_pool: Dict[int, RPCCallBase] = {}
  @rpc_require_initialized
  def rpc_register(call: RPCCallBase) -> int:
      r"""Registers a call for RPC requests."""
-     global _rpc_call_id, _rpc_call_pool
+     global _rpc_call_id

      with _rpc_call_lock:
          call_id = _rpc_call_id
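
The last hunk drops `_rpc_call_pool` from the `global` statement because `global` is only needed to rebind a module-level name; mutating the object it refers to works without it. In isolation (names here are illustrative, not from the diff):

    _pool = {}    # module-level registry
    _next_id = 0

    def register(item):
        global _next_id          # rebinding an int: `global` is required
        _pool[_next_id] = item   # mutating the dict: no `global` needed
        _next_id += 1
        return _next_id - 1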