pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Files changed (226)
  1. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +8 -3
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +159 -34
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +2 -4
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +322 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +53 -20
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
torch_geometric/datasets/tag_dataset.py (new file)
@@ -0,0 +1,350 @@
+ import os
+ import os.path as osp
+ from collections.abc import Sequence
+ from typing import Dict, List, Optional, Union
+
+ import numpy as np
+ import torch
+ from torch import Tensor
+ from tqdm import tqdm
+
+ from torch_geometric.data import InMemoryDataset, download_google_url
+ from torch_geometric.data.data import BaseData
+
+ try:
+     from pandas import DataFrame, read_csv
+     WITH_PANDAS = True
+ except ImportError:
+     WITH_PANDAS = False
+
+ IndexType = Union[slice, Tensor, np.ndarray, Sequence]
+
+
+ class TAGDataset(InMemoryDataset):
+     r"""The Text Attributed Graph datasets from the
+     `"Learning on Large-scale Text-attributed Graphs via Variational
+     Inference" <https://arxiv.org/abs/2210.14709>`_ paper.
+     This dataset transforms :obj:`ogbn-products` and :obj:`ogbn-arxiv`
+     into Text Attributed Graphs in which each node is associated with a
+     raw text, so that the dataset can be fed both to a DataLoader (for
+     LM training) and to a NeighborLoader (for GNN training). In
+     addition, this class can be used as a wrapper that converts an
+     InMemoryDataset, given a tokenizer and raw text, into a Text
+     Attributed Graph.
+
+     Args:
+         root (str): Root directory where the dataset should be saved.
+         dataset (InMemoryDataset): The dataset to be wrapped
+             (:obj:`"ogbn-products"`, :obj:`"ogbn-arxiv"`).
+         tokenizer_name (str): The name of the language model tokenizer.
+             Be sure to use the same name as the `model id` of the model
+             repository on huggingface.co.
+         text (List[str]): A list of raw texts, one per node, in the same
+             order as the node list.
+         split_idx (Optional[Dict[str, torch.Tensor]]): An optional
+             dictionary of split indices. Required if the wrapped dataset
+             does not provide a :meth:`get_idx_split` function.
+         tokenize_batch_size (int): The batch size used when tokenizing
+             the text. Tokenization runs on the CPU. (default: :obj:`256`)
+         token_on_disk (bool): Whether to save the tokens as :obj:`.pt`
+             files on disk. (default: :obj:`False`)
+         text_on_disk (bool): Whether to save the given text (a list of
+             strings) as a dataframe on disk. (default: :obj:`False`)
+         force_reload (bool): Whether to re-process the dataset.
+             (default: :obj:`False`)
+
+     .. note::
+         See `example/llm_plus_gnn/glem.py` for example usage.
+     """
+     raw_text_id = {
+         'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3',
+         'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt'
+     }
+
+     def __init__(self, root: str, dataset: InMemoryDataset,
+                  tokenizer_name: str, text: Optional[List[str]] = None,
+                  split_idx: Optional[Dict[str, Tensor]] = None,
+                  tokenize_batch_size: int = 256, token_on_disk: bool = False,
+                  text_on_disk: bool = False,
+                  force_reload: bool = False) -> None:
+         # Set the attributes needed before download & process run:
+         self.name = dataset.name
+         self.text = text
+         self.tokenizer_name = tokenizer_name
+         from transformers import AutoTokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+         if self.tokenizer.pad_token_id is None:
+             self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         self.dir_name = '_'.join(dataset.name.split('-'))
+         self.root = osp.join(root, self.dir_name)
+         missing_str_list = []
+         if not WITH_PANDAS:
+             missing_str_list.append('pandas')
+         if len(missing_str_list) > 0:
+             missing_str = ' '.join(missing_str_list)
+             error_out = f"`pip install {missing_str}` to use this dataset."
+             raise ImportError(error_out)
+         if hasattr(dataset, 'get_idx_split'):
+             self.split_idx = dataset.get_idx_split()
+         elif split_idx is not None:
+             self.split_idx = split_idx
+         else:
+             raise ValueError("TAGDataset needs a split index to generate "
+                              "the 'is_gold' mask. Please pass a "
+                              "dictionary with 'train', 'valid' and "
+                              "'test' index tensors to 'split_idx'")
+         if text is not None and text_on_disk:
+             self.save_node_text(text)
+         self.text_on_disk = text_on_disk
+         # __init__() will call download() and process():
+         super().__init__(self.root, transform=None, pre_transform=None,
+                          pre_filter=None, force_reload=force_reload)
+         # After downloading and processing, the dataset needs a
+         # BaseData object as `_data`:
+         assert dataset._data is not None
+         self._data = dataset._data  # Reassign reference.
+         assert self._data is not None
+         assert dataset._data.y is not None
+         assert isinstance(self._data, BaseData)
+         assert self._data.num_nodes is not None
+         assert isinstance(dataset._data.num_nodes, int)
+         assert isinstance(self._data.num_nodes, int)
+         self._n_id = torch.arange(self._data.num_nodes)
+         is_good_tensor = self.load_gold_mask()
+         self._is_gold = is_good_tensor.squeeze()
+         self._data['is_gold'] = is_good_tensor
+         if self.text is not None and len(self.text) != self._data.num_nodes:
+             raise ValueError("The number of text sequences in 'text' "
+                              "should be equal to the number of nodes!")
+         self.token_on_disk = token_on_disk
+         self.tokenize_batch_size = tokenize_batch_size
+         self._token = self.tokenize_graph(self.tokenize_batch_size)
+         self.__num_classes__ = dataset.num_classes
+
+     @property
+     def num_classes(self) -> int:
+         return self.__num_classes__
+
+     @property
+     def raw_file_names(self) -> List[str]:
+         file_names = []
+         for root, _, files in os.walk(osp.join(self.root, 'raw')):
+             for file in files:
+                 file_names.append(file)
+         return file_names
+
+     @property
+     def processed_file_names(self) -> List[str]:
+         return [
+             'geometric_data_processed.pt', 'pre_filter.pt',
+             'pre_transformed.pt'
+         ]
+
+     @property
+     def token(self) -> Dict[str, Tensor]:
+         if self._token is None:  # Lazy load:
+             self._token = self.tokenize_graph()
+         return self._token
+
+     # Load `is_gold` lazily after initialization:
+     @property
+     def is_gold(self) -> Tensor:
+         if self._is_gold is None:
+             print('Lazily loading is_gold!')
+             self._is_gold = self.load_gold_mask()
+         return self._is_gold
+
+     def get_n_id(self, node_idx: IndexType) -> Tensor:
+         if self._n_id is None:
+             assert self._data is not None
+             assert self._data.num_nodes is not None
+             assert isinstance(self._data.num_nodes, int)
+             self._n_id = torch.arange(self._data.num_nodes)
+         return self._n_id[node_idx]
+
+     def load_gold_mask(self) -> Tensor:
+         r"""Uses the original train split as the gold split, and generates
+         an 'is_gold' mask to distinguish ground-truth labels from pseudo
+         labels.
+         """
+         train_split_idx = self.get_idx_split()['train']
+         assert self._data is not None
+         assert self._data.num_nodes is not None
+         assert isinstance(self._data.num_nodes, int)
+         is_good_tensor = torch.zeros(self._data.num_nodes,
+                                      dtype=torch.bool).view(-1, 1)
+         is_good_tensor[train_split_idx] = True
+         return is_good_tensor
+
+     def get_gold(self, node_idx: IndexType) -> Tensor:
+         r"""Gets the gold mask for the given node indices.
+
+         Args:
+             node_idx (torch.Tensor): A tensor containing node indices.
+         """
+         if self._is_gold is None:
+             self._is_gold = self.is_gold
+         return self._is_gold[node_idx]
+
+     def get_idx_split(self) -> Dict[str, Tensor]:
+         return self.split_idx
+
+     def download(self) -> None:
+         print('Downloading raw text...')
+         raw_text_path = download_google_url(id=self.raw_text_id[self.name],
+                                             folder=f'{self.root}/raw',
+                                             filename='node-text.csv.gz',
+                                             log=True)
+         text_df = read_csv(raw_text_path)
+         self.text = list(text_df['text'])
+
+     def process(self) -> None:
+         if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')):
+             text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz'))
+             self.text = list(text_df['text'])
+         elif self.name in self.raw_text_id:
+             self.download()
+         else:
+             print('The dataset is neither ogbn-products nor ogbn-arxiv; '
+                   'please pass your list of raw text strings to `text`')
+         if self.text is None:
+             raise ValueError("TAGDataset only provides default raw text "
+                              "for ogbn-products and ogbn-arxiv. The raw "
+                              "text of each node is not specified. Please "
+                              "pass 'text' when converting your dataset "
+                              "to a Text Attributed Graph dataset.")
+
+     def save_node_text(self, text: List[str]) -> None:
+         node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz')
+         if osp.exists(node_text_path):
+             print(f'The raw text already exists at {node_text_path}')
+         else:
+             print(f'Saving raw text file at {node_text_path}')
+             os.makedirs(f'{self.root}/raw', exist_ok=True)
+             text_df = DataFrame(text, columns=['text'])
+             text_df.to_csv(osp.join(node_text_path), compression='gzip',
+                            index=False)
+
+     def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]:
+         r"""Tokenizes the text associated with each node (runs on the CPU).
+
+         Args:
+             batch_size (Optional[int]): The batch size of the text list
+                 used for generating embeddings.
+         Returns:
+             Dict[str, torch.Tensor]: The tokenized graph.
+         """
+         data_len = 0
+         if self.text is not None:
+             data_len = len(self.text)
+         else:
+             raise ValueError("TAGDataset needs text for tokenization")
+         token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
+         path = os.path.join(self.processed_dir, 'token', self.tokenizer_name)
+         # Check if the .pt files already exist:
+         token_files_exist = any(
+             os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys)
+
+         if token_files_exist and self.token_on_disk:
+             print('Found tokenized file, loading may take several minutes...')
+             all_encoded_token = {
+                 k: torch.load(os.path.join(path, f'{k}.pt'), weights_only=True)
+                 for k in token_keys
+                 if os.path.exists(os.path.join(path, f'{k}.pt'))
+             }
+             return all_encoded_token
+
+         all_encoded_token = {k: [] for k in token_keys}
+         pbar = tqdm(total=data_len)
+
+         pbar.set_description('Tokenizing Text Attributed Graph')
+         for i in range(0, data_len, batch_size):
+             end_index = min(data_len, i + batch_size)
+             token = self.tokenizer(self.text[i:end_index],
+                                    padding='max_length', truncation=True,
+                                    max_length=512, return_tensors="pt")
+             for k in token.keys():
+                 all_encoded_token[k].append(token[k])
+             pbar.update(end_index - i)
+         pbar.close()
+
+         all_encoded_token = {
+             k: torch.cat(v)
+             for k, v in all_encoded_token.items() if len(v) > 0
+         }
+         if self.token_on_disk:
+             os.makedirs(path, exist_ok=True)
+             print('Saving tokens on disk')
+             for k, tensor in all_encoded_token.items():
+                 torch.save(tensor, os.path.join(path, f'{k}.pt'))
+                 print('Token saved:', os.path.join(path, f'{k}.pt'))
+         os.environ["TOKENIZERS_PARALLELISM"] = 'true'  # Suppress warning.
+         return all_encoded_token
+
+     def __repr__(self) -> str:
+         return f'{self.__class__.__name__}()'
+
+     class TextDataset(torch.utils.data.Dataset):
+         r"""This nested dataset provides the textual data of each node in
+         the graph. Use the :meth:`to_text_dataset` factory method to
+         create a TextDataset from a TAGDataset.
+
+         Args:
+             tag_dataset (TAGDataset): The parent dataset.
+         """
+         def __init__(self, tag_dataset: 'TAGDataset') -> None:
+             self.tag_dataset = tag_dataset
+             self.token = tag_dataset.token
+             assert tag_dataset._data is not None
+             self._data = tag_dataset._data
+
+             assert tag_dataset._data.y is not None
+             self.labels = tag_dataset._data.y
+
+         def get_token(self, node_idx: IndexType) -> Dict[str, Tensor]:
+             r"""Called by :meth:`__getitem__`.
+
+             Args:
+                 node_idx (IndexType): The selected node indices in each
+                     batch.
+             Returns:
+                 items (Dict[str, Tensor]): The input for the LM.
+             """
+             items = {k: v[node_idx] for k, v in self.token.items()}
+             return items
+
+         # For LM training:
+         def __getitem__(
+             self, node_id: IndexType
+         ) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]:
+             r"""Overrides torch.utils.data.Dataset and is called when
+             iterating over batches in the dataloader. Make sure all of
+             the following key-value pairs are present in the returned
+             dictionary.
+
+             Args:
+                 node_id (List[int]): A list of node indices for selecting
+                     tokens, labels, etc., when iterating over the data
+                     loader for the LM.
+             Returns:
+                 items (dict): Input key-value pairs for language model
+                     training and inference.
+             """
+             item: Dict[str, Union[Tensor, Dict[str, Tensor]]] = {}
+             item['input'] = self.get_token(node_id)
+             item['labels'] = self.labels[node_id]
+             item['is_gold'] = self.tag_dataset.get_gold(node_id)
+             item['n_id'] = self.tag_dataset.get_n_id(node_id)
+             return item
+
+         def __len__(self) -> int:
+             assert self._data.num_nodes is not None
+             return self._data.num_nodes
+
+         def get(self, idx: int) -> BaseData:
+             return self._data
+
+         def __repr__(self) -> str:
+             return f'{self.__class__.__name__}()'
+
+     def to_text_dataset(self) -> TextDataset:
+         r"""Factory method that builds a TextDataset from this Text
+         Attributed Graph dataset, where each data point is the text
+         token sequence associated with a node.
+         """
+         return TAGDataset.TextDataset(self)
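For context, a minimal usage sketch of the new wrapper (not part of the diff; it assumes the optional `ogb` and `transformers` packages are installed, and that `TAGDataset` is exported from `torch_geometric.datasets`, as the `datasets/__init__.py` change above suggests):

    # Hypothetical usage: wrap an OGB node-property dataset as a TAG dataset.
    from ogb.nodeproppred import PygNodePropPredDataset

    from torch_geometric.datasets import TAGDataset

    dataset = PygNodePropPredDataset('ogbn-arxiv', root='data/ogb')
    tag_dataset = TAGDataset('data/ogb', dataset,
                             tokenizer_name='bert-base-uncased')
    text_dataset = tag_dataset.to_text_dataset()  # Per-node tokens for LM training.

The split indices come from the wrapped dataset's get_idx_split(), and the raw node texts for ogbn-arxiv and ogbn-products are downloaded automatically via the Google Drive IDs listed in `raw_text_id`.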
torch_geometric/datasets/upfd.py
@@ -3,7 +3,6 @@ import os.path as osp
  from typing import Callable, List, Optional
 
  import numpy as np
- import scipy.sparse as sp
  import torch
 
  from torch_geometric.data import (
@@ -130,6 +129,8 @@ class UPFD(InMemoryDataset):
          os.remove(path)
 
      def process(self) -> None:
+         import scipy.sparse as sp
+
          x = sp.load_npz(
              osp.join(self.raw_dir, f'new_{self.feature}_feature.npz'))
          x = torch.from_numpy(x.todense()).to(torch.float)
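This hunk moves the scipy.sparse import from module scope into process(), so importing the dataset module no longer requires SciPy; the optional dependency is only needed when the raw files are actually processed.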
torch_geometric/datasets/web_qsp_dataset.py (new file)
@@ -0,0 +1,246 @@
+ # Code adapted from the G-Retriever paper: https://arxiv.org/abs/2402.07630
+ from typing import Any, Dict, List, Tuple, no_type_check
+
+ import numpy as np
+ import torch
+ from torch import Tensor
+ from tqdm import tqdm
+
+ from torch_geometric.data import Data, InMemoryDataset
+ from torch_geometric.nn.nlp import SentenceTransformer
+
+
+ @no_type_check
+ def retrieval_via_pcst(
+     data: Data,
+     q_emb: Tensor,
+     textual_nodes: Any,
+     textual_edges: Any,
+     topk: int = 3,
+     topk_e: int = 3,
+     cost_e: float = 0.5,
+ ) -> Tuple[Data, str]:
+     c = 0.01
+
+     from pcst_fast import pcst_fast
+
+     root = -1
+     num_clusters = 1
+     pruning = 'gw'
+     verbosity_level = 0
+     if topk > 0:
+         n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.x)
+         topk = min(topk, data.num_nodes)
+         _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
+
+         n_prizes = torch.zeros_like(n_prizes)
+         n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
+     else:
+         n_prizes = torch.zeros(data.num_nodes)
+
+     if topk_e > 0:
+         e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.edge_attr)
+         topk_e = min(topk_e, e_prizes.unique().size(0))
+
+         topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e, largest=True)
+         e_prizes[e_prizes < topk_e_values[-1]] = 0.0
+         last_topk_e_value = topk_e
+         for k in range(topk_e):
+             indices = e_prizes == topk_e_values[k]
+             value = min((topk_e - k) / sum(indices), last_topk_e_value - c)
+             e_prizes[indices] = value
+             last_topk_e_value = value * (1 - c)
+         # Reduce the cost of the edges such that at least one edge is
+         # selected:
+         cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
+     else:
+         e_prizes = torch.zeros(data.num_edges)
+
+     costs = []
+     edges = []
+     virtual_n_prizes = []
+     virtual_edges = []
+     virtual_costs = []
+     mapping_n = {}
+     mapping_e = {}
+     for i, (src, dst) in enumerate(data.edge_index.t().numpy()):
+         prize_e = e_prizes[i]
+         if prize_e <= cost_e:
+             mapping_e[len(edges)] = i
+             edges.append((src, dst))
+             costs.append(cost_e - prize_e)
+         else:
+             virtual_node_id = data.num_nodes + len(virtual_n_prizes)
+             mapping_n[virtual_node_id] = i
+             virtual_edges.append((src, virtual_node_id))
+             virtual_edges.append((virtual_node_id, dst))
+             virtual_costs.append(0)
+             virtual_costs.append(0)
+             virtual_n_prizes.append(prize_e - cost_e)
+
+     prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])
+     num_edges = len(edges)
+     if len(virtual_costs) > 0:
+         costs = np.array(costs + virtual_costs)
+         edges = np.array(edges + virtual_edges)
+
+     vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters,
+                                 pruning, verbosity_level)
+
+     selected_nodes = vertices[vertices < data.num_nodes]
+     selected_edges = [mapping_e[e] for e in edges if e < num_edges]
+     virtual_vertices = vertices[vertices >= data.num_nodes]
+     if len(virtual_vertices) > 0:
+         virtual_vertices = vertices[vertices >= data.num_nodes]
+         virtual_edges = [mapping_n[i] for i in virtual_vertices]
+         selected_edges = np.array(selected_edges + virtual_edges)
+
+     edge_index = data.edge_index[:, selected_edges]
+     selected_nodes = np.unique(
+         np.concatenate(
+             [selected_nodes, edge_index[0].numpy(), edge_index[1].numpy()]))
+
+     n = textual_nodes.iloc[selected_nodes]
+     e = textual_edges.iloc[selected_edges]
+     desc = n.to_csv(index=False) + '\n' + e.to_csv(
+         index=False, columns=['src', 'edge_attr', 'dst'])
+
+     mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
+     src = [mapping[i] for i in edge_index[0].tolist()]
+     dst = [mapping[i] for i in edge_index[1].tolist()]
+
+     data = Data(
+         x=data.x[selected_nodes],
+         edge_index=torch.tensor([src, dst]),
+         edge_attr=data.edge_attr[selected_edges],
+     )
+
+     return data, desc
+
+
+ class WebQSPDataset(InMemoryDataset):
+     r"""The WebQuestionsSP dataset of the `"The Value of Semantic Parse
+     Labeling for Knowledge Base Question Answering"
+     <https://aclanthology.org/P16-2033/>`_ paper.
+
+     Args:
+         root (str): Root directory where the dataset should be saved.
+         split (str, optional): If :obj:`"train"`, loads the training dataset.
+             If :obj:`"val"`, loads the validation dataset.
+             If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
+         force_reload (bool, optional): Whether to re-process the dataset.
+             (default: :obj:`False`)
+         use_pcst (bool, optional): Whether to preprocess the dataset's graph
+             with PCST or return the full graphs. (default: :obj:`True`)
+     """
+     def __init__(
+         self,
+         root: str,
+         split: str = "train",
+         force_reload: bool = False,
+         use_pcst: bool = True,
+     ) -> None:
+         self.use_pcst = use_pcst
+         super().__init__(root, force_reload=force_reload)
+
+         if split not in {'train', 'val', 'test'}:
+             raise ValueError(f"Invalid 'split' argument (got {split})")
+
+         path = self.processed_paths[['train', 'val', 'test'].index(split)]
+         self.load(path)
+
+     @property
+     def processed_file_names(self) -> List[str]:
+         return ['train_data.pt', 'val_data.pt', 'test_data.pt']
+
+     def process(self) -> None:
+         import datasets
+         import pandas as pd
+
+         datasets = datasets.load_dataset('rmanluo/RoG-webqsp')
+
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         model_name = 'sentence-transformers/all-roberta-large-v1'
+         model = SentenceTransformer(model_name).to(device)
+         model.eval()
+
+         for dataset, path in zip(
+                 [datasets['train'], datasets['validation'], datasets['test']],
+                 self.processed_paths,
+         ):
+             questions = [example["question"] for example in dataset]
+             question_embs = model.encode(
+                 questions,
+                 batch_size=256,
+                 output_device='cpu',
+             )
+
+             data_list = []
+             for i, example in enumerate(tqdm(dataset)):
+                 raw_nodes: Dict[str, int] = {}
+                 raw_edges = []
+                 for tri in example["graph"]:
+                     h, r, t = tri
+                     h = h.lower()
+                     t = t.lower()
+                     if h not in raw_nodes:
+                         raw_nodes[h] = len(raw_nodes)
+                     if t not in raw_nodes:
+                         raw_nodes[t] = len(raw_nodes)
+                     raw_edges.append({
+                         "src": raw_nodes[h],
+                         "edge_attr": r,
+                         "dst": raw_nodes[t]
+                     })
+                 nodes = pd.DataFrame([{
+                     "node_id": v,
+                     "node_attr": k,
+                 } for k, v in raw_nodes.items()],
+                                      columns=["node_id", "node_attr"])
+                 edges = pd.DataFrame(raw_edges,
+                                      columns=["src", "edge_attr", "dst"])
+
+                 nodes.node_attr = nodes.node_attr.fillna("")
+                 x = model.encode(
+                     nodes.node_attr.tolist(),
+                     batch_size=256,
+                     output_device='cpu',
+                 )
+                 edge_attr = model.encode(
+                     edges.edge_attr.tolist(),
+                     batch_size=256,
+                     output_device='cpu',
+                 )
+                 edge_index = torch.tensor([
+                     edges.src.tolist(),
+                     edges.dst.tolist(),
+                 ], dtype=torch.long)
+
+                 question = f"Question: {example['question']}\nAnswer: "
+                 label = '|'.join(example['answer']).lower()
+                 data = Data(
+                     x=x,
+                     edge_index=edge_index,
+                     edge_attr=edge_attr,
+                 )
+                 if self.use_pcst and len(nodes) > 0 and len(edges) > 0:
+                     data, desc = retrieval_via_pcst(
+                         data,
+                         question_embs[i],
+                         nodes,
+                         edges,
+                         topk=3,
+                         topk_e=5,
+                         cost_e=0.5,
+                     )
+                 else:
+                     desc = nodes.to_csv(index=False) + "\n" + edges.to_csv(
+                         index=False,
+                         columns=["src", "edge_attr", "dst"],
+                     )
+
+                 data.question = question
+                 data.label = label
+                 data.desc = desc
+                 data_list.append(data)
+
+             self.save(data_list, path)
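For context, a minimal usage sketch of the new dataset (not part of the diff; processing requires the optional `datasets`, `pandas`, and `pcst_fast` packages, and benefits from a GPU for the sentence-transformer encoding):

    # Hypothetical usage: load the PCST-pruned training split of WebQSP.
    from torch_geometric.datasets import WebQSPDataset

    train_dataset = WebQSPDataset(root='data/webqsp', split='train', use_pcst=True)
    data = train_dataset[0]  # Data with x, edge_index, edge_attr, question, label, desc.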
torch_geometric/datasets/webkb.py
@@ -102,7 +102,7 @@ class WebKB(InMemoryDataset):
              download_url(f'{self.url}/splits/{f}', self.raw_dir)
 
      def process(self) -> None:
-         with open(self.raw_paths[0], 'r') as f:
+         with open(self.raw_paths[0]) as f:
              lines = f.read().split('\n')[1:-1]
              xs = [[float(value) for value in line.split('\t')[1].split(',')]
                    for line in lines]
@@ -111,7 +111,7 @@ class WebKB(InMemoryDataset):
              ys = [int(line.split('\t')[2]) for line in lines]
              y = torch.tensor(ys, dtype=torch.long)
 
-         with open(self.raw_paths[1], 'r') as f:
+         with open(self.raw_paths[1]) as f:
              lines = f.read().split('\n')[1:-1]
              edge_indices = [[int(value) for value in line.split('\t')]
                              for line in lines]
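These open() hunks, here and in the files below, are pure style cleanups: 'r' is the default mode of open(), so dropping the argument does not change behavior.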
torch_geometric/datasets/wikics.py
@@ -65,7 +65,7 @@ class WikiCS(InMemoryDataset):
              download_url(f'{self.url}/{name}', self.raw_dir)
 
      def process(self) -> None:
-         with open(self.raw_paths[0], 'r') as f:
+         with open(self.raw_paths[0]) as f:
              data = json.load(f)
 
          x = torch.tensor(data['features'], dtype=torch.float)
torch_geometric/datasets/wikidata.py
@@ -10,6 +10,7 @@ from torch_geometric.data import (
      download_url,
      extract_tar,
  )
+ from torch_geometric.io import fs
 
 
  class Wikidata5M(InMemoryDataset):
@@ -99,7 +100,7 @@ class Wikidata5M(InMemoryDataset):
              values = line.strip().split('\t')
              entity_to_id[values[0]] = i
 
-         x = torch.load(self.raw_paths[1])
+         x = fs.torch_load(self.raw_paths[1])
 
          edge_indices = []
          edge_types = []
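Swapping torch.load for fs.torch_load routes the load through torch_geometric.io.fs (itself extended in this release; see io/fs.py in the file list above), which presumably allows loading raw files from fsspec-backed, possibly remote, filesystems.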
@@ -107,7 +108,7 @@
 
          rel_to_id: Dict[str, int] = {}
          for split, path in enumerate(self.raw_paths[2:]):
-             with open(path, 'r') as f:
+             with open(path) as f:
                  for line in f:
                      head, rel, tail = line[:-1].split('\t')
                      src = entity_to_id[head]
torch_geometric/datasets/wikipedia_network.py
@@ -105,7 +105,7 @@ class WikipediaNetwork(InMemoryDataset):
 
      def process(self) -> None:
          if self.geom_gcn_preprocess:
-             with open(self.raw_paths[0], 'r') as f:
+             with open(self.raw_paths[0]) as f:
                  lines = f.read().split('\n')[1:-1]
              xs = [[float(value) for value in line.split('\t')[1].split(',')]
                    for line in lines]
@@ -113,7 +113,7 @@
              ys = [int(line.split('\t')[2]) for line in lines]
              y = torch.tensor(ys, dtype=torch.long)
 
-             with open(self.raw_paths[1], 'r') as f:
+             with open(self.raw_paths[1]) as f:
                  lines = f.read().split('\n')[1:-1]
                  edge_indices = [[int(value) for value in line.split('\t')]
                                  for line in lines]