pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pyg-nightly might be problematic.
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +77 -53
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +226 -189
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +14 -2
- torch_geometric/_compile.py +9 -3
- torch_geometric/_onnx.py +214 -0
- torch_geometric/config_mixin.py +5 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +109 -5
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +14 -11
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +18 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +15 -17
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +310 -209
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/partition.py +2 -2
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +18 -14
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +87 -3
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +14 -5
- torch_geometric/inspector.py +4 -0
- torch_geometric/io/fs.py +5 -4
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/{nn/nlp → llm/models}/llm.py +179 -31
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +4 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +14 -0
- torch_geometric/metrics/link_pred.py +745 -92
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +5 -4
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +0 -20
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +14 -12
- torch_geometric/nn/model_hub.py +2 -15
- torch_geometric/nn/models/__init__.py +11 -2
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/cluster_pool.py +3 -4
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +336 -7
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +296 -23
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -8
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/typing.py +70 -17
- torch_geometric/utils/__init__.py +4 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +27 -12
- torch_geometric/utils/_scatter.py +132 -195
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +13 -9
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +7 -14
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/sentence_transformer.py +0 -101
@@ -0,0 +1,462 @@
+import csv
+import os
+import os.path as osp
+from collections.abc import Sequence
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from torch import Tensor
+from tqdm import tqdm
+
+from torch_geometric.data import InMemoryDataset, download_google_url
+from torch_geometric.data.data import BaseData
+from torch_geometric.io import fs
+
+try:
+    from pandas import DataFrame, read_csv
+    WITH_PANDAS = True
+except ImportError:
+    WITH_PANDAS = False
+
+IndexType = Union[slice, Tensor, np.ndarray, Sequence]
+
+
+class TAGDataset(InMemoryDataset):
+    r"""The Text Attributed Graph datasets from the
+    `"Learning on Large-scale Text-attributed Graphs via Variational Inference"
+    <https://arxiv.org/abs/2210.14709>`_ paper and `"Harnessing Explanations:
+    LLM-to-LM Interpreter for Enhanced Text-Attributed Graph Representation
+    Learning" <https://arxiv.org/abs/2305.19523>`_ paper.
+    This dataset transforms `ogbn-products` and `ogbn-arxiv` into Text
+    Attributed Graphs in which each node is associated with raw text, an LLM
+    prediction and an LLM explanation, so that the dataset can be used with a
+    DataLoader (for LM training) and a NeighborLoader (for GNN training).
+    In addition, this class can be used as a wrapper that converts an
+    InMemoryDataset together with a tokenizer and text into a Text Attributed
+    Graph.
+
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        dataset (InMemoryDataset): The name of the dataset
+            (:obj:`"ogbn-products"`, :obj:`"ogbn-arxiv"`).
+        tokenizer_name (str): The tokenizer name for the language model.
+            Be sure to use the same name as the `model id` of the model repo
+            on huggingface.co.
+        text (List[str]): A list of raw text associated with the nodes; the
+            order of the list should align with the node order.
+        split_idx (Optional[Dict[str, torch.Tensor]]): An optional dictionary
+            of split indices; required if your dataset does not provide a
+            get_idx_split function.
+        tokenize_batch_size (int): The batch size used when tokenizing text;
+            tokenization runs on the CPU. (default: 256)
+        token_on_disk (bool): Whether to save the tokens as .pt files on disk.
+            (default: False)
+        text_on_disk (bool): Whether to save the given text (list of str) as a
+            dataframe on disk. (default: False)
+        force_reload (bool): Whether to re-process the dataset.
+            (default: False)
+    .. note::
+        See `example/llm/glem.py` for example usage.
+    """
+    raw_text_id = {
+        'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3',
+        'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt'
+    }
+
+    llm_prediction_url = 'https://github.com/XiaoxinHe/TAPE/raw/main/gpt_preds'
+
+    llm_explanation_id = {
+        'ogbn-arxiv': '1o8n2xRen-N_elF9NQpIca0iCHJgEJbRQ',
+    }
+
+    def __init__(
+        self,
+        root: str,
+        dataset: InMemoryDataset,
+        tokenizer_name: str,
+        text: Optional[List[str]] = None,
+        split_idx: Optional[Dict[str, Tensor]] = None,
+        tokenize_batch_size: int = 256,
+        token_on_disk: bool = False,
+        text_on_disk: bool = False,
+        force_reload: bool = False,
+    ) -> None:
+        # list the vars you want to pass in before running download & process
+        self.name = dataset.name
+        self.text = text
+        self.llm_prediction_topk = 5
+        self.tokenizer_name = tokenizer_name
+        from transformers import AutoTokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        self.dir_name = '_'.join(dataset.name.split('-'))
+        self.root = osp.join(root, self.dir_name)
+        missing_str_list = []
+        if not WITH_PANDAS:
+            missing_str_list.append('pandas')
+        if len(missing_str_list) > 0:
+            missing_str = ' '.join(missing_str_list)
+            error_out = f"`pip install {missing_str}` to use this dataset."
+            raise ImportError(error_out)
+        if hasattr(dataset, 'get_idx_split'):
+            self.split_idx = dataset.get_idx_split()
+        elif split_idx is not None:
+            self.split_idx = split_idx
+        else:
+            raise ValueError("TAGDataset needs a split index to generate the "
+                             "is_gold mask. Please pass a dictionary with "
+                             "'train', 'valid' and 'test' index tensors to "
+                             "'split_idx'.")
+        if text_on_disk:
+            if text is not None:
+                self.save_node_text(text)
+        self.text_on_disk = text_on_disk
+        # init will call download and process
+        super().__init__(self.root, transform=None, pre_transform=None,
+                         pre_filter=None, force_reload=force_reload)
+        # after processing and download
+        # Dataset has to have BaseData as _data
+        assert dataset._data is not None
+        self._data = dataset._data  # reassign reference
+        assert self._data is not None
+        assert dataset._data.y is not None
+        assert isinstance(self._data, BaseData)
+        assert self._data.num_nodes is not None
+        assert isinstance(dataset._data.num_nodes, int)
+        assert isinstance(self._data.num_nodes, int)
+        self._n_id = torch.arange(self._data.num_nodes)
+        is_good_tensor = self.load_gold_mask()
+        self._is_gold = is_good_tensor.squeeze()
+        self._data['is_gold'] = is_good_tensor
+        if self.text is not None and len(self.text) != self._data.num_nodes:
+            raise ValueError("The number of text sequences in 'text' should "
+                             "be equal to the number of nodes!")
+        self.token_on_disk = token_on_disk
+        self.tokenize_batch_size = tokenize_batch_size
+        self._token = self.tokenize_graph(self.tokenize_batch_size)
+        self._llm_explanation_token: Dict[str, Tensor] = {}
+        self._all_token: Dict[str, Tensor] = {}
+        if self.name in self.llm_explanation_id:
+            self._llm_explanation_token = self.tokenize_graph(
+                self.tokenize_batch_size, text_type='llm_explanation')
+            self._all_token = self.tokenize_graph(self.tokenize_batch_size,
+                                                  text_type='all')
+        self.__num_classes__ = dataset.num_classes
+
+    @property
+    def num_classes(self) -> int:
+        return self.__num_classes__
+
+    @property
+    def raw_file_names(self) -> List[str]:
+        file_names = []
+        for _, _, files in os.walk(osp.join(self.root, 'raw')):
+            for file in files:
+                file_names.append(file)
+        return file_names
+
+    @property
+    def processed_file_names(self) -> List[str]:
+        return [
+            'geometric_data_processed.pt', 'pre_filter.pt',
+            'pre_transformed.pt'
+        ]
+
+    @property
+    def token(self) -> Dict[str, Tensor]:
+        if self._token is None:  # lazy load
+            self._token = self.tokenize_graph()
+        return self._token
+
+    @property
+    def llm_explanation_token(self) -> Dict[str, Tensor]:
+        if self._llm_explanation_token is None and \
+                self.name in self.llm_explanation_id:
+            self._llm_explanation_token = self.tokenize_graph(
+                text_type='llm_explanation')
+        return self._llm_explanation_token
+
+    @property
+    def all_token(self) -> Dict[str, Tensor]:
+        if self._all_token is None and \
+                self.name in self.llm_explanation_id:
+            self._all_token = self.tokenize_graph(text_type='all')
+        return self._all_token
+
+    # load is_gold after init
+    @property
+    def is_gold(self) -> Tensor:
+        if self._is_gold is None:
+            print('lazy load is_gold!!')
+            self._is_gold = self.load_gold_mask()
+        return self._is_gold
+
+    def get_n_id(self, node_idx: IndexType) -> Tensor:
+        if self._n_id is None:
+            assert self._data is not None
+            assert self._data.num_nodes is not None
+            assert isinstance(self._data.num_nodes, int)
+            self._n_id = torch.arange(self._data.num_nodes)
+        return self._n_id[node_idx]
+
+    def load_gold_mask(self) -> Tensor:
+        r"""Use the original train split as the gold split, generating an
+        is_gold mask for picking ground-truth labels and pseudo labels.
+        """
+        train_split_idx = self.get_idx_split()['train']
+        assert self._data is not None
+        assert self._data.num_nodes is not None
+        assert isinstance(self._data.num_nodes, int)
+        is_good_tensor = torch.zeros(self._data.num_nodes,
+                                     dtype=torch.bool).view(-1, 1)
+        is_good_tensor[train_split_idx] = True
+        return is_good_tensor
+
+    def get_gold(self, node_idx: IndexType) -> Tensor:
+        r"""Get the gold mask for the given node_idx.
+
+        Args:
+            node_idx (torch.Tensor): A tensor containing node indices.
+        """
+        if self._is_gold is None:
+            self._is_gold = self.is_gold
+        return self._is_gold[node_idx]
+
+    def get_idx_split(self) -> Dict[str, Tensor]:
+        return self.split_idx
+
+    def download(self) -> None:
+        print('downloading raw text')
+        raw_text_path = download_google_url(id=self.raw_text_id[self.name],
+                                            folder=f'{self.root}/raw',
+                                            filename='node-text.csv.gz',
+                                            log=True)
+        self.text = list(read_csv(raw_text_path)['text'])
+        if self.name in self.llm_explanation_id:
+            print('downloading llm explanations')
+            llm_explanation_path = download_google_url(
+                id=self.llm_explanation_id[self.name],
+                folder=f'{self.root}/raw', filename='node-gpt-response.csv.gz',
+                log=True)
+            self.llm_explanation = list(read_csv(llm_explanation_path)['text'])
+            print('downloading llm predictions')
+            fs.cp(f'{self.llm_prediction_url}/{self.name}.csv', self.raw_dir)
+
+    def process(self) -> None:
+        # process title and abstract
+        if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')):
+            text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz'))
+            self.text = list(text_df['text'])
+        elif self.name in self.raw_text_id:
+            self.download()
+        else:
+            print('The dataset is not ogbn-products nor ogbn-arxiv, '
+                  'please pass in your raw text string list to `text`')
+        if self.text is None:
+            raise ValueError("The TAGDataset only provides raw text for "
+                             "ogbn-products and ogbn-arxiv by default. "
+                             "The raw text of each node is not specified. "
+                             "Please pass in 'text' when converting your "
+                             "dataset to a Text Attributed Graph dataset.")
+        # process LLM explanation and prediction
+        llm_explanation_path = f'{self.raw_dir}/node-gpt-response.csv.gz'
+        llm_prediction_path = f'{self.raw_dir}/{self.name}.csv'
+        if osp.exists(llm_explanation_path) and osp.exists(
+                llm_prediction_path):
+            # load LLM explanation
+            self.llm_explanation = list(read_csv(llm_explanation_path)['text'])
+            # load LLM prediction
+            preds = []
+            with open(llm_prediction_path) as file:
+                reader = csv.reader(file)
+                for row in reader:
+                    inner_list = []
+                    for value in row:
+                        inner_list.append(int(value))
+                    preds.append(inner_list)
+
+            pl = torch.zeros(len(preds), self.llm_prediction_topk,
+                             dtype=torch.long)
+            for i, pred in enumerate(preds):
+                pl[i][:len(pred)] = torch.tensor(
+                    pred[:self.llm_prediction_topk], dtype=torch.long) + 1
+
+            if self.llm_explanation is None or pl is None:
+                raise ValueError(
+                    "The TAGDataset only provides ogbn-arxiv LLM explanations "
+                    "and predictions by default. The LLM explanation and "
+                    "prediction of each node is not specified. Please pass in "
+                    "'llm_explanation' and 'llm_prediction' when converting "
+                    "your dataset to a Text Attributed Graph dataset.")
+        elif self.name in self.llm_explanation_id:
+            self.download()
+        else:
+            print(
+                'The dataset is not ogbn-arxiv, '
+                'please pass in your llm explanation list to '
+                '`llm_explanation` and llm prediction list to '
+                '`llm_prediction`')
+
+    def save_node_text(self, text: List[str]) -> None:
+        node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz')
+        if osp.exists(node_text_path):
+            print(f'The raw text already exists at {node_text_path}')
+        else:
+            print(f'Saving raw text file at {node_text_path}')
+            os.makedirs(f'{self.root}/raw', exist_ok=True)
+            text_df = DataFrame(text, columns=['text'])
+            text_df.to_csv(osp.join(node_text_path), compression='gzip',
+                           index=False)
+
+    def tokenize_graph(self, batch_size: int = 256,
+                       text_type: str = 'raw_text') -> Dict[str, Tensor]:
+        r"""Tokenize the text associated with each node, running on the CPU.
+
+        Args:
+            batch_size (Optional[int]): The batch size of the list of text
+                used for generating embeddings.
+            text_type (Optional[str]): The type of text to tokenize.
+        Returns:
+            Dict[str, torch.Tensor]: The tokenized graph.
+        """
+        assert text_type in ['raw_text', 'llm_explanation', 'all']
+        if text_type == 'raw_text':
+            _text = self.text
+        elif text_type == 'llm_explanation':
+            _text = self.llm_explanation
+        elif text_type == 'all':
+            if self.text is None or self.llm_explanation is None:
+                raise ValueError("The TAGDataset needs text and LLM "
+                                 "explanations for tokenizing all text")
+            _text = [
+                f'{raw_txt} Explanation: {exp_txt}'
+                for raw_txt, exp_txt in zip(self.text, self.llm_explanation)
+            ]
+
+        data_len = 0
+        if _text is not None:
+            data_len = len(_text)
+        else:
+            raise ValueError("The TAGDataset needs text for tokenization")
+        token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
+        path = os.path.join(self.processed_dir, 'token', text_type,
+                            self.tokenizer_name)
+        # Check if the .pt files already exist
+        token_files_exist = any(
+            os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys)
+
+        if token_files_exist and self.token_on_disk:
+            print('Found tokenized file, loading may take several minutes...')
+            all_encoded_token = {
+                k: torch.load(os.path.join(path, f'{k}.pt'), weights_only=True)
+                for k in token_keys
+                if os.path.exists(os.path.join(path, f'{k}.pt'))
+            }
+            return all_encoded_token
+
+        all_encoded_token = {k: [] for k in token_keys}
+        pbar = tqdm(total=data_len)
+
+        pbar.set_description(f'Tokenizing Text Attributed Graph {text_type}')
+        for i in range(0, data_len, batch_size):
+            end_index = min(data_len, i + batch_size)
+            token = self.tokenizer(_text[i:end_index], padding='max_length',
+                                   truncation=True, max_length=512,
+                                   return_tensors="pt")
+            for k in token.keys():
+                all_encoded_token[k].append(token[k])
+            pbar.update(end_index - i)
+        pbar.close()
+
+        all_encoded_token = {
+            k: torch.cat(v)
+            for k, v in all_encoded_token.items() if len(v) > 0
+        }
+        if self.token_on_disk:
+            os.makedirs(path, exist_ok=True)
+            print('Saving tokens on Disk')
+            for k, tensor in all_encoded_token.items():
+                torch.save(tensor, os.path.join(path, f'{k}.pt'))
+                print('Token saved:', os.path.join(path, f'{k}.pt'))
+        os.environ["TOKENIZERS_PARALLELISM"] = 'true'  # suppressing warning
+        return all_encoded_token
+
+    def __repr__(self) -> str:
+        return f'{self.__class__.__name__}()'
+
+    class TextDataset(torch.utils.data.Dataset):
+        r"""This nested dataset provides textual data for each node in the
+        graph. Use the to_text_dataset() factory method to create it from a
+        TAGDataset.
+
+        Args:
+            tag_dataset (TAGDataset): The parent dataset.
+            text_type (str): The type of text.
+        """
+        def __init__(self, tag_dataset: 'TAGDataset',
+                     text_type: str = 'raw_text') -> None:
+            assert text_type in ['raw_text', 'llm_explanation', 'all']
+            self.tag_dataset = tag_dataset
+            if text_type == 'raw_text':
+                self.token = tag_dataset.token
+            elif text_type == 'llm_explanation':
+                self.token = tag_dataset.llm_explanation_token
+            elif text_type == 'all':
+                self.token = tag_dataset.all_token
+            assert tag_dataset._data is not None
+            self._data = tag_dataset._data
+
+            assert tag_dataset._data.y is not None
+            self.labels = tag_dataset._data.y
+
+        def get_token(self, node_idx: IndexType) -> Dict[str, Tensor]:
+            r"""This function will be called in __getitem__().
+
+            Args:
+                node_idx (IndexType): The selected node indices in each batch.
+            Returns:
+                items (Dict[str, Tensor]): Input for the LM.
+            """
+            items = {k: v[node_idx] for k, v in self.token.items()}
+            return items
+
+        # for LM training
+        def __getitem__(
+            self,
+            node_id: IndexType,
+        ) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]:
+            r"""This function overrides the one in torch.utils.data.Dataset
+            and is called when iterating batches in the dataloader. Make sure
+            all of the following key-value pairs are present in the returned
+            dict.
+
+            Args:
+                node_id (List[int]): A list of node indices for selecting
+                    tokens, labels etc. when iterating the dataloader for LM.
+            Returns:
+                items (dict): Input key-value pairs for language model
+                    training and inference.
+            """
+            item: Dict[str, Union[Tensor, Dict[str, Tensor]]] = {}
+            item['input'] = self.get_token(node_id)
+            item['labels'] = self.labels[node_id]
+            item['is_gold'] = self.tag_dataset.get_gold(node_id)
+            item['n_id'] = self.tag_dataset.get_n_id(node_id)
+            return item
+
+        def __len__(self) -> int:
+            assert self._data.num_nodes is not None
+            return self._data.num_nodes
+
+        def get(self, idx: int) -> BaseData:
+            return self._data
+
+        def __repr__(self) -> str:
+            return f'{self.__class__.__name__}()'
+
+    def to_text_dataset(self, text_type: str = 'raw_text') -> TextDataset:
+        r"""Factory method that builds a text dataset from a Text Attributed
+        Graph dataset; each data point is a node's associated text tokens.
+        """
+        return TAGDataset.TextDataset(self, text_type)