pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pyg-nightly might be problematic.

Files changed (268)
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
torch_geometric/datasets/tag_dataset.py (new file, entry 68 above)
@@ -0,0 +1,462 @@
+import csv
+import os
+import os.path as osp
+from collections.abc import Sequence
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from torch import Tensor
+from tqdm import tqdm
+
+from torch_geometric.data import InMemoryDataset, download_google_url
+from torch_geometric.data.data import BaseData
+from torch_geometric.io import fs
+
+try:
+    from pandas import DataFrame, read_csv
+    WITH_PANDAS = True
+except ImportError:
+    WITH_PANDAS = False
+
+IndexType = Union[slice, Tensor, np.ndarray, Sequence]
+
+
+class TAGDataset(InMemoryDataset):
+    r"""The Text Attributed Graph datasets from the
+    `"Learning on Large-scale Text-attributed Graphs via Variational
+    Inference" <https://arxiv.org/abs/2210.14709>`_ and `"Harnessing
+    Explanations: LLM-to-LM Interpreter for Enhanced Text-Attributed Graph
+    Representation Learning" <https://arxiv.org/abs/2305.19523>`_ papers.
+
+    This dataset transforms :obj:`ogbn-products` and :obj:`ogbn-arxiv` into
+    Text Attributed Graphs, in which each node is associated with raw text,
+    an LLM prediction, and an LLM explanation, so that the dataset can be
+    adapted to a DataLoader (for LM training) and a NeighborLoader (for GNN
+    training). In addition, this class can be used as a wrapper to convert
+    an InMemoryDataset with a tokenizer and raw text into a Text Attributed
+    Graph.
+
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        dataset (InMemoryDataset): The dataset to be converted
+            (:obj:`"ogbn-products"`, :obj:`"ogbn-arxiv"`).
+        tokenizer_name (str): The tokenizer name of the language model.
+            Be sure to use the same tokenizer name as the `model id` of the
+            model repo on huggingface.co.
+        text (List[str]): A list of raw text associated with the nodes. The
+            order of the list should be aligned with the node list.
+        split_idx (Optional[Dict[str, torch.Tensor]]): An optional
+            dictionary holding the split indices. It is required if your
+            dataset does not provide a get_idx_split function.
+        tokenize_batch_size (int): The batch size used when tokenizing the
+            text. Tokenization runs on the CPU. (default: :obj:`256`)
+        token_on_disk (bool): Whether to save the tokens as .pt files on
+            disk. (default: :obj:`False`)
+        text_on_disk (bool): Whether to save the given text (list of str)
+            as a dataframe on disk. (default: :obj:`False`)
+        force_reload (bool): Whether to re-process the dataset.
+            (default: :obj:`False`)
+
+    .. note::
+        See `example/llm/glem.py` for example usage.
+    """
+    raw_text_id = {
+        'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3',
+        'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt'
+    }
+
+    llm_prediction_url = 'https://github.com/XiaoxinHe/TAPE/raw/main/gpt_preds'
+
+    llm_explanation_id = {
+        'ogbn-arxiv': '1o8n2xRen-N_elF9NQpIca0iCHJgEJbRQ',
+    }
+
+    def __init__(
+        self,
+        root: str,
+        dataset: InMemoryDataset,
+        tokenizer_name: str,
+        text: Optional[List[str]] = None,
+        split_idx: Optional[Dict[str, Tensor]] = None,
+        tokenize_batch_size: int = 256,
+        token_on_disk: bool = False,
+        text_on_disk: bool = False,
+        force_reload: bool = False,
+    ) -> None:
+        # List the variables to set up before running download and process:
+        self.name = dataset.name
+        self.text = text
+        self.llm_prediction_topk = 5
+        self.tokenizer_name = tokenizer_name
+        from transformers import AutoTokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        self.dir_name = '_'.join(dataset.name.split('-'))
+        self.root = osp.join(root, self.dir_name)
+        missing_str_list = []
+        if not WITH_PANDAS:
+            missing_str_list.append('pandas')
+        if len(missing_str_list) > 0:
+            missing_str = ' '.join(missing_str_list)
+            error_out = f"`pip install {missing_str}` to use this dataset."
+            raise ImportError(error_out)
+        if hasattr(dataset, 'get_idx_split'):
+            self.split_idx = dataset.get_idx_split()
+        elif split_idx is not None:
+            self.split_idx = split_idx
+        else:
+            raise ValueError("TAGDataset needs a split index for generating "
+                             "the is_gold mask. Please pass the split "
+                             "indices as a dictionary with 'train', 'valid' "
+                             "and 'test' index tensors to 'split_idx'.")
+        if text_on_disk:
+            if text is not None:
+                self.save_node_text(text)
+        self.text_on_disk = text_on_disk
+        # __init__() will call download() and process():
+        super().__init__(self.root, transform=None, pre_transform=None,
+                         pre_filter=None, force_reload=force_reload)
+        # After downloading and processing, the dataset has to hold a
+        # BaseData object as _data:
+        assert dataset._data is not None
+        self._data = dataset._data  # Reassign reference.
+        assert self._data is not None
+        assert dataset._data.y is not None
+        assert isinstance(self._data, BaseData)
+        assert self._data.num_nodes is not None
+        assert isinstance(dataset._data.num_nodes, int)
+        assert isinstance(self._data.num_nodes, int)
+        self._n_id = torch.arange(self._data.num_nodes)
+        is_good_tensor = self.load_gold_mask()
+        self._is_gold = is_good_tensor.squeeze()
+        self._data['is_gold'] = is_good_tensor
+        if self.text is not None and len(self.text) != self._data.num_nodes:
+            raise ValueError("The number of text sequences in 'text' should "
+                             "be equal to the number of nodes!")
+        self.token_on_disk = token_on_disk
+        self.tokenize_batch_size = tokenize_batch_size
+        self._token = self.tokenize_graph(self.tokenize_batch_size)
+        self._llm_explanation_token: Dict[str, Tensor] = {}
+        self._all_token: Dict[str, Tensor] = {}
+        if self.name in self.llm_explanation_id:
+            self._llm_explanation_token = self.tokenize_graph(
+                self.tokenize_batch_size, text_type='llm_explanation')
+            self._all_token = self.tokenize_graph(self.tokenize_batch_size,
+                                                  text_type='all')
+        self.__num_classes__ = dataset.num_classes
+
+    @property
+    def num_classes(self) -> int:
+        return self.__num_classes__
+
+    @property
+    def raw_file_names(self) -> List[str]:
+        file_names = []
+        for _, _, files in os.walk(osp.join(self.root, 'raw')):
+            for file in files:
+                file_names.append(file)
+        return file_names
+
+    @property
+    def processed_file_names(self) -> List[str]:
+        return [
+            'geometric_data_processed.pt', 'pre_filter.pt',
+            'pre_transformed.pt'
+        ]
+
+    @property
+    def token(self) -> Dict[str, Tensor]:
+        if self._token is None:  # Lazy load:
+            self._token = self.tokenize_graph()
+        return self._token
+
+    @property
+    def llm_explanation_token(self) -> Dict[str, Tensor]:
+        if self._llm_explanation_token is None and \
+                self.name in self.llm_explanation_id:
+            self._llm_explanation_token = self.tokenize_graph(
+                text_type='llm_explanation')
+        return self._llm_explanation_token
+
+    @property
+    def all_token(self) -> Dict[str, Tensor]:
+        if self._all_token is None and \
+                self.name in self.llm_explanation_id:
+            self._all_token = self.tokenize_graph(text_type='all')
+        return self._all_token
+
+    # Load is_gold after initialization:
+    @property
+    def is_gold(self) -> Tensor:
+        if self._is_gold is None:
+            print('lazy load is_gold!!')
+            self._is_gold = self.load_gold_mask()
+        return self._is_gold
+
+    def get_n_id(self, node_idx: IndexType) -> Tensor:
+        if self._n_id is None:
+            assert self._data is not None
+            assert self._data.num_nodes is not None
+            assert isinstance(self._data.num_nodes, int)
+            self._n_id = torch.arange(self._data.num_nodes)
+        return self._n_id[node_idx]
+
+    def load_gold_mask(self) -> Tensor:
+        r"""Use the original train split as the gold split, generating an
+        is_gold mask for picking ground-truth labels and pseudo labels.
+        """
+        train_split_idx = self.get_idx_split()['train']
+        assert self._data is not None
+        assert self._data.num_nodes is not None
+        assert isinstance(self._data.num_nodes, int)
+        is_good_tensor = torch.zeros(self._data.num_nodes,
+                                     dtype=torch.bool).view(-1, 1)
+        is_good_tensor[train_split_idx] = True
+        return is_good_tensor
+
+    def get_gold(self, node_idx: IndexType) -> Tensor:
+        r"""Get the gold mask for the given node indices.
+
+        Args:
+            node_idx (torch.Tensor): A tensor containing node indices.
+        """
+        if self._is_gold is None:
+            self._is_gold = self.is_gold
+        return self._is_gold[node_idx]
+
+    def get_idx_split(self) -> Dict[str, Tensor]:
+        return self.split_idx
+
+    def download(self) -> None:
+        print('downloading raw text')
+        raw_text_path = download_google_url(id=self.raw_text_id[self.name],
+                                            folder=f'{self.root}/raw',
+                                            filename='node-text.csv.gz',
+                                            log=True)
+        self.text = list(read_csv(raw_text_path)['text'])
+        if self.name in self.llm_explanation_id:
+            print('downloading llm explanations')
+            llm_explanation_path = download_google_url(
+                id=self.llm_explanation_id[self.name],
+                folder=f'{self.root}/raw', filename='node-gpt-response.csv.gz',
+                log=True)
+            self.llm_explanation = list(read_csv(llm_explanation_path)['text'])
+        print('downloading llm predictions')
+        fs.cp(f'{self.llm_prediction_url}/{self.name}.csv', self.raw_dir)
+
+    def process(self) -> None:
+        # Process title and abstract:
+        if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')):
+            text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz'))
+            self.text = list(text_df['text'])
+        elif self.name in self.raw_text_id:
+            self.download()
+        else:
+            print('The dataset is neither ogbn-products nor ogbn-arxiv, '
+                  'please pass your raw text string list to `text`')
+        if self.text is None:
+            raise ValueError("TAGDataset only provides ogbn-products and "
+                             "ogbn-arxiv raw text by default. The raw text "
+                             "of each node is not specified. Please pass in "
+                             "'text' when converting your dataset to a Text "
+                             "Attributed Graph dataset.")
+        # Process LLM explanations and predictions:
+        llm_explanation_path = f'{self.raw_dir}/node-gpt-response.csv.gz'
+        llm_prediction_path = f'{self.raw_dir}/{self.name}.csv'
+        if osp.exists(llm_explanation_path) and osp.exists(
+                llm_prediction_path):
+            # Load LLM explanations:
+            self.llm_explanation = list(read_csv(llm_explanation_path)['text'])
+            # Load LLM predictions:
+            preds = []
+            with open(llm_prediction_path) as file:
+                reader = csv.reader(file)
+                for row in reader:
+                    inner_list = []
+                    for value in row:
+                        inner_list.append(int(value))
+                    preds.append(inner_list)
+
+            pl = torch.zeros(len(preds), self.llm_prediction_topk,
+                             dtype=torch.long)
+            for i, pred in enumerate(preds):
+                pl[i][:len(pred)] = torch.tensor(
+                    pred[:self.llm_prediction_topk], dtype=torch.long) + 1
+
+            if self.llm_explanation is None or pl is None:
+                raise ValueError(
+                    "TAGDataset only provides ogbn-arxiv LLM explanations "
+                    "and predictions by default. The LLM explanation and "
+                    "prediction of each node is not specified. Please pass "
+                    "in 'llm_explanation' and 'llm_prediction' when "
+                    "converting your dataset to a Text Attributed Graph "
+                    "dataset.")
+        elif self.name in self.llm_explanation_id:
+            self.download()
+        else:
+            print(
+                'The dataset is not ogbn-arxiv, '
+                'please pass your LLM explanation list to `llm_explanation` '
+                'and LLM prediction list to `llm_prediction`')
+
+    def save_node_text(self, text: List[str]) -> None:
+        node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz')
+        if osp.exists(node_text_path):
+            print(f'The raw text already exists at {node_text_path}')
+        else:
+            print(f'Saving raw text file to {node_text_path}')
+            os.makedirs(f'{self.root}/raw', exist_ok=True)
+            text_df = DataFrame(text, columns=['text'])
+            text_df.to_csv(osp.join(node_text_path), compression='gzip',
+                           index=False)
+
+    def tokenize_graph(self, batch_size: int = 256,
+                       text_type: str = 'raw_text') -> Dict[str, Tensor]:
+        r"""Tokenize the text associated with each node, running on the CPU.
+
+        Args:
+            batch_size (int): The batch size of the list of text when
+                tokenizing. (default: :obj:`256`)
+            text_type (str): The type of text to tokenize.
+                (default: :obj:`"raw_text"`)
+
+        Returns:
+            Dict[str, torch.Tensor]: The tokenized graph.
+        """
+        assert text_type in ['raw_text', 'llm_explanation', 'all']
+        if text_type == 'raw_text':
+            _text = self.text
+        elif text_type == 'llm_explanation':
+            _text = self.llm_explanation
+        elif text_type == 'all':
+            if self.text is None or self.llm_explanation is None:
+                raise ValueError("TAGDataset needs text and LLM "
+                                 "explanations for tokenizing all text")
+            _text = [
+                f'{raw_txt} Explanation: {exp_txt}'
+                for raw_txt, exp_txt in zip(self.text, self.llm_explanation)
+            ]
+
+        data_len = 0
+        if _text is not None:
+            data_len = len(_text)
+        else:
+            raise ValueError("TAGDataset needs text for tokenization")
+        token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
+        path = os.path.join(self.processed_dir, 'token', text_type,
+                            self.tokenizer_name)
+        # Check if the .pt files already exist:
+        token_files_exist = any(
+            os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys)
+
+        if token_files_exist and self.token_on_disk:
+            print('Found tokenized file, loading may take several minutes...')
+            all_encoded_token = {
+                k: torch.load(os.path.join(path, f'{k}.pt'), weights_only=True)
+                for k in token_keys
+                if os.path.exists(os.path.join(path, f'{k}.pt'))
+            }
+            return all_encoded_token
+
+        all_encoded_token = {k: [] for k in token_keys}
+        pbar = tqdm(total=data_len)
+
+        pbar.set_description(f'Tokenizing Text Attributed Graph {text_type}')
+        for i in range(0, data_len, batch_size):
+            end_index = min(data_len, i + batch_size)
+            token = self.tokenizer(_text[i:end_index], padding='max_length',
+                                   truncation=True, max_length=512,
+                                   return_tensors="pt")
+            for k in token.keys():
+                all_encoded_token[k].append(token[k])
+            pbar.update(end_index - i)
+        pbar.close()
+
+        all_encoded_token = {
+            k: torch.cat(v)
+            for k, v in all_encoded_token.items() if len(v) > 0
+        }
+        if self.token_on_disk:
+            os.makedirs(path, exist_ok=True)
+            print('Saving tokens on disk')
+            for k, tensor in all_encoded_token.items():
+                torch.save(tensor, os.path.join(path, f'{k}.pt'))
+                print('Token saved:', os.path.join(path, f'{k}.pt'))
+        os.environ["TOKENIZERS_PARALLELISM"] = 'true'  # Suppress warning.
+        return all_encoded_token
+
+    def __repr__(self) -> str:
+        return f'{self.__class__.__name__}()'
+
+    class TextDataset(torch.utils.data.Dataset):
+        r"""This nested dataset provides the textual data of each node in
+        the graph. Use the to_text_dataset factory method to create a
+        TextDataset from a TAGDataset.
+
+        Args:
+            tag_dataset (TAGDataset): The parent dataset.
+            text_type (str): The type of text. (default: :obj:`"raw_text"`)
+        """
+        def __init__(self, tag_dataset: 'TAGDataset',
+                     text_type: str = 'raw_text') -> None:
+            assert text_type in ['raw_text', 'llm_explanation', 'all']
+            self.tag_dataset = tag_dataset
+            if text_type == 'raw_text':
+                self.token = tag_dataset.token
+            elif text_type == 'llm_explanation':
+                self.token = tag_dataset.llm_explanation_token
+            elif text_type == 'all':
+                self.token = tag_dataset.all_token
+            assert tag_dataset._data is not None
+            self._data = tag_dataset._data
+
+            assert tag_dataset._data.y is not None
+            self.labels = tag_dataset._data.y
+
+        def get_token(self, node_idx: IndexType) -> Dict[str, Tensor]:
+            r"""This function will be called in __getitem__().
+
+            Args:
+                node_idx (IndexType): The selected node indices in each
+                    batch.
+
+            Returns:
+                items (Dict[str, Tensor]): The input for the LM.
+            """
+            items = {k: v[node_idx] for k, v in self.token.items()}
+            return items
+
+        # For LM training:
+        def __getitem__(
+            self,
+            node_id: IndexType,
+        ) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]:
+            r"""This function overrides the one in
+            torch.utils.data.Dataset and is called when iterating over
+            batches in the dataloader. Make sure all of the following
+            key-value pairs are present in the returned dict.
+
+            Args:
+                node_id (IndexType): The node indices for selecting tokens,
+                    labels, etc. when iterating over the data loader for
+                    the LM.
+
+            Returns:
+                items (dict): The input key-value pairs for language model
+                    training and inference.
+            """
+            item: Dict[str, Union[Tensor, Dict[str, Tensor]]] = {}
+            item['input'] = self.get_token(node_id)
+            item['labels'] = self.labels[node_id]
+            item['is_gold'] = self.tag_dataset.get_gold(node_id)
+            item['n_id'] = self.tag_dataset.get_n_id(node_id)
+            return item
+
+        def __len__(self) -> int:
+            assert self._data.num_nodes is not None
+            return self._data.num_nodes
+
+        def get(self, idx: int) -> BaseData:
+            return self._data
+
+        def __repr__(self) -> str:
+            return f'{self.__class__.__name__}()'
+
+    def to_text_dataset(self, text_type: str = 'raw_text') -> TextDataset:
+        r"""Factory method that builds a text dataset from a Text Attributed
+        Graph dataset, where each data point is a node's associated text
+        tokens.
+        """
+        return TAGDataset.TextDataset(self, text_type)
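
The TAGDataset added above is designed to feed the same graph to two training loops: a language model consumes per-node token batches through the nested TextDataset, while a GNN consumes neighborhood samples of the wrapped Data object. Below is a minimal usage sketch, assuming ogb is installed and using 'bert-base-uncased' as an illustrative tokenizer; the variable names and loader settings are assumptions rather than part of the diff, and `example/llm/glem.py` in the package remains the canonical reference.

# Minimal usage sketch for the new TAGDataset (illustrative, not from the
# diff). Assumes `ogb` is installed and that TAGDataset is re-exported from
# torch_geometric.datasets (its __init__.py is changed above, +20 -1).
from ogb.nodeproppred import PygNodePropPredDataset
from torch.utils.data import DataLoader

from torch_geometric.datasets import TAGDataset
from torch_geometric.loader import NeighborLoader

dataset = PygNodePropPredDataset('ogbn-arxiv', root='./data')
tag_dataset = TAGDataset(root='./data', dataset=dataset,
                         tokenizer_name='bert-base-uncased',
                         token_on_disk=True)

# LM side: per-node tokens, labels, and is_gold masks via TextDataset:
text_dataset = tag_dataset.to_text_dataset()
text_loader = DataLoader(text_dataset, batch_size=32, shuffle=True)

# GNN side: neighbor sampling over the same graph (with 'is_gold' attached):
train_idx = tag_dataset.get_idx_split()['train']
graph_loader = NeighborLoader(
    tag_dataset._data,  # The wrapped Data object, as reassigned in __init__.
    num_neighbors=[10, 10],
    batch_size=128,
    input_nodes=train_idx,
)

Batches from text_loader carry 'n_id' and 'is_gold' alongside tokens and labels, so a GLEM-style pseudo-labeling loop can align the LM view and the GNN view of the same nodes.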