pyg-nightly 2.7.0.dev20241119__py3-none-any.whl → 2.7.0.dev20241120__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20241119.dist-info → pyg_nightly-2.7.0.dev20241120.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20241119.dist-info → pyg_nightly-2.7.0.dev20241120.dist-info}/RECORD +14 -9
- torch_geometric/__init__.py +1 -1
- torch_geometric/datasets/__init__.py +4 -0
- torch_geometric/datasets/molecule_gpt_dataset.py +480 -0
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/models/__init__.py +4 -1
- torch_geometric/nn/models/glem.py +384 -0
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/nlp/llm.py +1 -1
- torch_geometric/nn/nlp/sentence_transformer.py +3 -0
- {pyg_nightly-2.7.0.dev20241119.dist-info → pyg_nightly-2.7.0.dev20241120.dist-info}/WHEEL +0 -0
torch_geometric/datasets/tag_dataset.py

@@ -0,0 +1,350 @@
+import os
+import os.path as osp
+from collections.abc import Sequence
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from torch import Tensor
+from tqdm import tqdm
+
+from torch_geometric.data import InMemoryDataset, download_google_url
+from torch_geometric.data.data import BaseData
+
+try:
+    from pandas import DataFrame, read_csv
+    WITH_PANDAS = True
+except ImportError:
+    WITH_PANDAS = False
+
+IndexType = Union[slice, Tensor, np.ndarray, Sequence]
+
+
+class TAGDataset(InMemoryDataset):
+    r"""The Text-Attributed Graph datasets from the
+    `"Learning on Large-scale Text-attributed Graphs via Variational
+    Inference" <https://arxiv.org/abs/2210.14709>`_ paper.
+    This dataset transforms :obj:`ogbn-products` and :obj:`ogbn-arxiv`
+    into Text-Attributed Graphs in which each node is associated with raw
+    text, so that the dataset can be fed to a DataLoader (for LM training)
+    and a NeighborLoader (for GNN training). In addition, this class can be
+    used as a wrapper to convert any InMemoryDataset with a tokenizer and
+    raw text into a Text-Attributed Graph.
+
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        dataset (InMemoryDataset): The dataset to be wrapped
+            (:obj:`"ogbn-products"`, :obj:`"ogbn-arxiv"`).
+        tokenizer_name (str): The tokenizer name of the language model.
+            Be sure to use the same name as the `model id` of the model
+            repository on huggingface.co.
+        text (List[str]): A list of raw text associated with the nodes; the
+            order of the list should be aligned with the node list.
+        split_idx (Optional[Dict[str, torch.Tensor]]): A dictionary holding
+            the split indices. It is required if your dataset does not
+            provide a get_idx_split function.
+        tokenize_batch_size (int): The batch size for tokenizing the text;
+            tokenization runs on CPU. (default: 256)
+        token_on_disk (bool): Whether to save the tokens as .pt files on
+            disk. (default: False)
+        text_on_disk (bool): Whether to save the given text (list of str) as
+            a dataframe on disk. (default: False)
+        force_reload (bool): Whether to re-process the dataset. (default: False)
+    .. note::
+        See `example/llm_plus_gnn/glem.py` for example usage.
+    """
+    raw_text_id = {
+        'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3',
+        'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt'
+    }
+
+    def __init__(self, root: str, dataset: InMemoryDataset,
+                 tokenizer_name: str, text: Optional[List[str]] = None,
+                 split_idx: Optional[Dict[str, Tensor]] = None,
+                 tokenize_batch_size: int = 256, token_on_disk: bool = False,
+                 text_on_disk: bool = False,
+                 force_reload: bool = False) -> None:
+        # Set the attributes that are needed before download & process run:
+        self.name = dataset.name
+        self.text = text
+        self.tokenizer_name = tokenizer_name
+        from transformers import AutoTokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        self.dir_name = '_'.join(dataset.name.split('-'))
+        self.root = osp.join(root, self.dir_name)
+        missing_str_list = []
+        if not WITH_PANDAS:
+            missing_str_list.append('pandas')
+        if len(missing_str_list) > 0:
+            missing_str = ' '.join(missing_str_list)
+            error_out = f"`pip install {missing_str}` to use this dataset."
+            raise ImportError(error_out)
+        if hasattr(dataset, 'get_idx_split'):
+            self.split_idx = dataset.get_idx_split()
+        elif split_idx is not None:
+            self.split_idx = split_idx
+        else:
+            raise ValueError("TAGDataset needs split indices to generate "
+                             "the is_gold mask. Please pass the split "
+                             "indices as a dictionary with 'train', 'valid' "
+                             "and 'test' index tensors to 'split_idx'")
+        if text is not None and text_on_disk:
+            self.save_node_text(text)
+        self.text_on_disk = text_on_disk
+        # __init__() will call download and process:
+        super().__init__(self.root, transform=None, pre_transform=None,
+                         pre_filter=None, force_reload=force_reload)
+        # After processing and download,
+        # the dataset has to hold BaseData as _data:
+        assert dataset._data is not None
+        self._data = dataset._data  # Reassign reference.
+        assert self._data is not None
+        assert dataset._data.y is not None
+        assert isinstance(self._data, BaseData)
+        assert self._data.num_nodes is not None
+        assert isinstance(dataset._data.num_nodes, int)
+        assert isinstance(self._data.num_nodes, int)
+        self._n_id = torch.arange(self._data.num_nodes)
+        is_good_tensor = self.load_gold_mask()
+        self._is_gold = is_good_tensor.squeeze()
+        self._data['is_gold'] = is_good_tensor
+        if self.text is not None and len(self.text) != self._data.num_nodes:
+            raise ValueError("The number of text sequences in 'text' should "
+                             "be equal to the number of nodes!")
+        self.token_on_disk = token_on_disk
+        self.tokenize_batch_size = tokenize_batch_size
+        self._token = self.tokenize_graph(self.tokenize_batch_size)
+        self.__num_classes__ = dataset.num_classes
+
+    @property
+    def num_classes(self) -> int:
+        return self.__num_classes__
+
+    @property
+    def raw_file_names(self) -> List[str]:
+        file_names = []
+        for root, _, files in os.walk(osp.join(self.root, 'raw')):
+            for file in files:
+                file_names.append(file)
+        return file_names
+
+    @property
+    def processed_file_names(self) -> List[str]:
+        return [
+            'geometric_data_processed.pt', 'pre_filter.pt',
+            'pre_transformed.pt'
+        ]
+
+    @property
+    def token(self) -> Dict[str, Tensor]:
+        if self._token is None:  # Lazy load:
+            self._token = self.tokenize_graph()
+        return self._token
+
+    # Load is_gold after initialization:
+    @property
+    def is_gold(self) -> Tensor:
+        if self._is_gold is None:
+            print('lazy load is_gold!!')
+            self._is_gold = self.load_gold_mask()
+        return self._is_gold
+
+    def get_n_id(self, node_idx: IndexType) -> Tensor:
+        if self._n_id is None:
+            assert self._data is not None
+            assert self._data.num_nodes is not None
+            assert isinstance(self._data.num_nodes, int)
+            self._n_id = torch.arange(self._data.num_nodes)
+        return self._n_id[node_idx]
+
+    def load_gold_mask(self) -> Tensor:
+        r"""Use the original train split as the gold split, generating the
+        is_gold mask for picking ground-truth labels and pseudo labels.
+        """
+        train_split_idx = self.get_idx_split()['train']
+        assert self._data is not None
+        assert self._data.num_nodes is not None
+        assert isinstance(self._data.num_nodes, int)
+        is_good_tensor = torch.zeros(self._data.num_nodes,
+                                     dtype=torch.bool).view(-1, 1)
+        is_good_tensor[train_split_idx] = True
+        return is_good_tensor
+
+    def get_gold(self, node_idx: IndexType) -> Tensor:
+        r"""Get the gold mask for the given node indices.
+
+        Args:
+            node_idx (torch.Tensor): A tensor containing node indices.
+        """
+        if self._is_gold is None:
+            self._is_gold = self.is_gold
+        return self._is_gold[node_idx]
+
+    def get_idx_split(self) -> Dict[str, Tensor]:
+        return self.split_idx
+
+    def download(self) -> None:
+        print('downloading raw text')
+        raw_text_path = download_google_url(id=self.raw_text_id[self.name],
+                                            folder=f'{self.root}/raw',
+                                            filename='node-text.csv.gz',
+                                            log=True)
+        text_df = read_csv(raw_text_path)
+        self.text = list(text_df['text'])
+
+    def process(self) -> None:
+        if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')):
+            text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz'))
+            self.text = list(text_df['text'])
+        elif self.name in self.raw_text_id:
+            self.download()
+        else:
+            print('The dataset is neither ogbn-products nor ogbn-arxiv, '
+                  'please pass your raw text string list to `text`')
+        if self.text is None:
+            raise ValueError("TAGDataset only provides default raw text for "
+                             "ogbn-products and ogbn-arxiv. The raw text of "
+                             "each node is not specified. Please pass 'text' "
+                             "when converting your dataset to a "
+                             "Text-Attributed Graph dataset.")
+
+    def save_node_text(self, text: List[str]) -> None:
+        node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz')
+        if osp.exists(node_text_path):
+            print(f'The raw text already exists at {node_text_path}')
+        else:
+            print(f'Saving raw text file at {node_text_path}')
+            os.makedirs(f'{self.root}/raw', exist_ok=True)
+            text_df = DataFrame(text, columns=['text'])
+            text_df.to_csv(osp.join(node_text_path), compression='gzip',
+                           index=False)
+
+    def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]:
+        r"""Tokenize the text associated with each node (runs on CPU).
+
+        Args:
+            batch_size (Optional[int]): The batch size of the text list used
+                for generating the token encodings.
+        Returns:
+            Dict[str, torch.Tensor]: The tokenized graph.
+        """
+        data_len = 0
+        if self.text is not None:
+            data_len = len(self.text)
+        else:
+            raise ValueError("TAGDataset needs text for tokenization")
+        token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
+        path = os.path.join(self.processed_dir, 'token', self.tokenizer_name)
+        # Check if the .pt files already exist:
+        token_files_exist = any(
+            os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys)
+
+        if token_files_exist and self.token_on_disk:
+            print('Found tokenized file, loading may take several minutes...')
+            all_encoded_token = {
+                k: torch.load(os.path.join(path, f'{k}.pt'), weights_only=True)
+                for k in token_keys
+                if os.path.exists(os.path.join(path, f'{k}.pt'))
+            }
+            return all_encoded_token
+
+        all_encoded_token = {k: [] for k in token_keys}
+        pbar = tqdm(total=data_len)
+
+        pbar.set_description('Tokenizing Text Attributed Graph')
+        for i in range(0, data_len, batch_size):
+            end_index = min(data_len, i + batch_size)
+            token = self.tokenizer(self.text[i:min(i + batch_size, data_len)],
+                                   padding='max_length', truncation=True,
+                                   max_length=512, return_tensors="pt")
+            for k in token.keys():
+                all_encoded_token[k].append(token[k])
+            pbar.update(end_index - i)
+        pbar.close()
+
+        all_encoded_token = {
+            k: torch.cat(v)
+            for k, v in all_encoded_token.items() if len(v) > 0
+        }
+        if self.token_on_disk:
+            os.makedirs(path, exist_ok=True)
+            print('Saving tokens on Disk')
+            for k, tensor in all_encoded_token.items():
+                torch.save(tensor, os.path.join(path, f'{k}.pt'))
+                print('Token saved:', os.path.join(path, f'{k}.pt'))
+        os.environ["TOKENIZERS_PARALLELISM"] = 'true'  # Suppress warning.
+        return all_encoded_token
+
+    def __repr__(self) -> str:
+        return f'{self.__class__.__name__}()'
+
+    class TextDataset(torch.utils.data.Dataset):
+        r"""This nested dataset provides the tokenized text for each node in
+        the graph. Create it via :meth:`TAGDataset.to_text_dataset`.
+
+        Args:
+            tag_dataset (TAGDataset): The parent dataset.
+        """
+        def __init__(self, tag_dataset: 'TAGDataset') -> None:
+            self.tag_dataset = tag_dataset
+            self.token = tag_dataset.token
+            assert tag_dataset._data is not None
+            self._data = tag_dataset._data
+
+            assert tag_dataset._data.y is not None
+            self.labels = tag_dataset._data.y
+
+        def get_token(self, node_idx: IndexType) -> Dict[str, Tensor]:
+            r"""This function will be called in __getitem__().
+
+            Args:
+                node_idx (IndexType): The selected node indices in each batch.
+            Returns:
+                items (Dict[str, Tensor]): Input for the LM.
+            """
+            items = {k: v[node_idx] for k, v in self.token.items()}
+            return items
+
+        # For LM training:
+        def __getitem__(
+            self, node_id: IndexType
+        ) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]:
+            r"""This function overrides the method of the same name in
+            torch.utils.data.Dataset and is called when iterating batches
+            in the dataloader. All of the following key-value pairs are
+            present in the returned dict.
+
+            Args:
+                node_id (IndexType): The node indices for selecting tokens,
+                    labels, etc., when iterating the data loader for the LM.
+            Returns:
+                items (dict): Input key-value pairs for language model
+                    training and inference.
+            """
+            item: Dict[str, Union[Tensor, Dict[str, Tensor]]] = {}
+            item['input'] = self.get_token(node_id)
+            item['labels'] = self.labels[node_id]
+            item['is_gold'] = self.tag_dataset.get_gold(node_id)
+            item['n_id'] = self.tag_dataset.get_n_id(node_id)
+            return item
+
+        def __len__(self) -> int:
+            assert self._data.num_nodes is not None
+            return self._data.num_nodes
+
+        def get(self, idx: int) -> BaseData:
+            return self._data
+
+        def __repr__(self) -> str:
+            return f'{self.__class__.__name__}()'
+
+    def to_text_dataset(self) -> TextDataset:
+        r"""Factory method that builds a TextDataset from this
+        Text-Attributed Graph dataset; each item is a node's text tokens.
+        """
+        return TAGDataset.TextDataset(self)
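Not part of the diff itself: a minimal usage sketch for the new TAGDataset, assuming the ogb package is installed and that the updated torch_geometric/datasets/__init__.py (+4 -0 above) re-exports TAGDataset; the root path and tokenizer name are placeholders.

from ogb.nodeproppred import PygNodePropPredDataset
from torch.utils.data import DataLoader

from torch_geometric.datasets import TAGDataset

dataset = PygNodePropPredDataset('ogbn-arxiv', root='data')
tag_dataset = TAGDataset('data', dataset,
                         tokenizer_name='prajjwal1/bert-tiny',  # placeholder
                         token_on_disk=True)  # raw node text is downloaded automatically
text_dataset = tag_dataset.to_text_dataset()  # one tokenized text item per node

loader = DataLoader(text_dataset, batch_size=32, shuffle=True)
batch = next(iter(loader))
# batch['input'] holds the tokenizer outputs ('input_ids', 'attention_mask', ...),
# batch['labels'] the node labels, batch['is_gold'] marks nodes from the original
# train split, and batch['n_id'] the global node indices.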
torch_geometric/nn/attention/qformer.py

@@ -0,0 +1,71 @@
+from typing import Callable
+
+import torch
+
+
+class QFormer(torch.nn.Module):
+    r"""The Querying Transformer (Q-Former) from the
+    `"BLIP-2: Bootstrapping Language-Image Pre-training
+    with Frozen Image Encoders and Large Language Models"
+    <https://arxiv.org/pdf/2301.12597>`_ paper.
+
+    Args:
+        input_dim (int): The number of features in the input.
+        hidden_dim (int): The dimension of the FFN in each encoder layer.
+        output_dim (int): The final output dimension.
+        num_heads (int): The number of attention heads.
+        num_layers (int): The number of sub-encoder-layers in the encoder.
+        dropout (float): The dropout value in each encoder layer.
+
+
+    .. note::
+        This is a simplified version of the original Q-Former implementation.
+    """
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_heads: int,
+        num_layers: int,
+        dropout: float = 0.0,
+        activation: Callable = torch.nn.ReLU(),
+    ) -> None:
+
+        super().__init__()
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+
+        self.layer_norm = torch.nn.LayerNorm(input_dim)
+        self.encoder_layer = torch.nn.TransformerEncoderLayer(
+            d_model=input_dim,
+            nhead=num_heads,
+            dim_feedforward=hidden_dim,
+            dropout=dropout,
+            activation=activation,
+            batch_first=True,
+        )
+        self.encoder = torch.nn.TransformerEncoder(
+            self.encoder_layer,
+            num_layers=num_layers,
+        )
+        self.project = torch.nn.Linear(input_dim, output_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): Input sequence to the encoder layer.
+                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
+                batch-size :math:`B`, sequence length :math:`N`,
+                and feature dimension :math:`F`.
+        """
+        x = self.layer_norm(x)
+        x = self.encoder(x)
+        out = self.project(x)
+        return out
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}('
+                f'num_heads={self.num_heads}, '
+                f'num_layers={self.num_layers})')
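Not part of the diff itself: a quick shape-check sketch for the new QFormer block, imported directly from its module path (the torch_geometric/nn/attention/__init__.py change above presumably re-exports it as well). Note that input_dim must be divisible by num_heads, as required by the underlying torch.nn.TransformerEncoderLayer.

import torch

from torch_geometric.nn.attention.qformer import QFormer

# input_dim must be divisible by num_heads (MultiheadAttention constraint).
model = QFormer(input_dim=16, hidden_dim=32, output_dim=8, num_heads=4,
                num_layers=2)

x = torch.randn(2, 10, 16)  # [batch_size, seq_len, input_dim]
out = model(x)              # LayerNorm -> TransformerEncoder -> Linear
assert out.shape == (2, 10, 8)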
torch_geometric/nn/models/__init__.py

@@ -29,7 +29,8 @@ from .pmlp import PMLP
 from .neural_fingerprint import NeuralFingerprint
 from .visnet import ViSNet
 from .g_retriever import GRetriever
-
+from .molecule_gpt import MoleculeGPT
+from .glem import GLEM
 # Deprecated:
 from torch_geometric.explain.algorithm.captum import (to_captum_input,
                                                        captum_output_to_dicts)
@@ -77,4 +78,6 @@ __all__ = classes = [
     'NeuralFingerprint',
     'ViSNet',
     'GRetriever',
+    'MoleculeGPT',
+    'GLEM',
 ]
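Not part of the diff itself: given the __all__ additions above, both new models should be importable from the models namespace:

from torch_geometric.nn.models import GLEM, MoleculeGPT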