pyg-nightly 2.7.0.dev20241119__py3-none-any.whl → 2.7.0.dev20241120__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,350 @@
+ import os
+ import os.path as osp
+ from collections.abc import Sequence
+ from typing import Dict, List, Optional, Union
+
+ import numpy as np
+ import torch
+ from torch import Tensor
+ from tqdm import tqdm
+
+ from torch_geometric.data import InMemoryDataset, download_google_url
+ from torch_geometric.data.data import BaseData
+
+ try:
+     from pandas import DataFrame, read_csv
+     WITH_PANDAS = True
+ except ImportError:
+     WITH_PANDAS = False
+
+ IndexType = Union[slice, Tensor, np.ndarray, Sequence]
+
+
+ class TAGDataset(InMemoryDataset):
+     r"""The Text Attributed Graph datasets from the
+     `"Learning on Large-scale Text-attributed Graphs via Variational Inference"
+     <https://arxiv.org/abs/2210.14709>`_ paper.
+     This dataset transforms :obj:`"ogbn-products"` and :obj:`"ogbn-arxiv"`
+     into Text Attributed Graphs in which each node is associated with a raw
+     text string. The resulting dataset can be used with a DataLoader (for LM
+     training) as well as a NeighborLoader (for GNN training). In addition,
+     this class can be used as a wrapper that converts an InMemoryDataset,
+     together with a tokenizer and raw text, into a Text Attributed Graph.
+
+     Args:
+         root (str): Root directory where the dataset should be saved.
+         dataset (InMemoryDataset): The dataset to be wrapped
+             (:obj:`"ogbn-products"`, :obj:`"ogbn-arxiv"`).
+         tokenizer_name (str): The tokenizer name for the language model.
+             Be sure to use the same name as the `model id` of the model
+             repository on huggingface.co.
+         text (List[str]): A list of raw text associated with the nodes.
+             The order of the list must align with the node list.
+         split_idx (Optional[Dict[str, torch.Tensor]]): A dictionary holding
+             the split indices (:obj:`'train'`, :obj:`'valid'`, :obj:`'test'`).
+             It is required if your dataset does not provide a
+             :meth:`get_idx_split` function.
+         tokenize_batch_size (int): The batch size used when tokenizing the
+             text. Tokenization runs on CPU. (default: :obj:`256`)
+         token_on_disk (bool): Whether to save the tokens as :obj:`.pt` files
+             on disk. (default: :obj:`False`)
+         text_on_disk (bool): Whether to save the given text (list of str) as
+             a dataframe on disk. (default: :obj:`False`)
+         force_reload (bool): Whether to re-process the dataset.
+             (default: :obj:`False`)
+
+     .. note::
+         See `example/llm_plus_gnn/glem.py` for example usage.
+     """
+     raw_text_id = {
+         'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3',
+         'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt'
+     }
+
+     def __init__(self, root: str, dataset: InMemoryDataset,
+                  tokenizer_name: str, text: Optional[List[str]] = None,
+                  split_idx: Optional[Dict[str, Tensor]] = None,
+                  tokenize_batch_size: int = 256, token_on_disk: bool = False,
+                  text_on_disk: bool = False,
+                  force_reload: bool = False) -> None:
+         # Set the attributes needed before `download()` and `process()` run:
+         self.name = dataset.name
+         self.text = text
+         self.tokenizer_name = tokenizer_name
+         from transformers import AutoTokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+         if self.tokenizer.pad_token_id is None:
+             self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         self.dir_name = '_'.join(dataset.name.split('-'))
+         self.root = osp.join(root, self.dir_name)
+         missing_str_list = []
+         if not WITH_PANDAS:
+             missing_str_list.append('pandas')
+         if len(missing_str_list) > 0:
+             missing_str = ' '.join(missing_str_list)
+             error_out = f"`pip install {missing_str}` to use this dataset."
+             raise ImportError(error_out)
+         if hasattr(dataset, 'get_idx_split'):
+             self.split_idx = dataset.get_idx_split()
+         elif split_idx is not None:
+             self.split_idx = split_idx
+         else:
+             raise ValueError(
+                 "TAGDataset needs split indices to generate the 'is_gold' "
+                 "mask. Please pass a dictionary with 'train', 'valid' and "
+                 "'test' index tensors to 'split_idx'")
+         if text is not None and text_on_disk:
+             self.save_node_text(text)
+         self.text_on_disk = text_on_disk
+         # `InMemoryDataset.__init__` calls `download()` and `process()`:
+         super().__init__(self.root, transform=None, pre_transform=None,
+                          pre_filter=None, force_reload=force_reload)
+         # After downloading and processing, the wrapped dataset must expose
+         # a `BaseData` object as `_data`:
+         assert dataset._data is not None
+         self._data = dataset._data  # Reassign reference.
+         assert self._data is not None
+         assert dataset._data.y is not None
+         assert isinstance(self._data, BaseData)
+         assert self._data.num_nodes is not None
+         assert isinstance(dataset._data.num_nodes, int)
+         assert isinstance(self._data.num_nodes, int)
+         self._n_id = torch.arange(self._data.num_nodes)
+         is_good_tensor = self.load_gold_mask()
+         self._is_gold = is_good_tensor.squeeze()
+         self._data['is_gold'] = is_good_tensor
+         if self.text is not None and len(self.text) != self._data.num_nodes:
+             raise ValueError("The number of text sequences in 'text' should "
+                              "be equal to the number of nodes!")
+         self.token_on_disk = token_on_disk
+         self.tokenize_batch_size = tokenize_batch_size
+         self._token = self.tokenize_graph(self.tokenize_batch_size)
+         self.__num_classes__ = dataset.num_classes
+
+     @property
+     def num_classes(self) -> int:
+         return self.__num_classes__
+
+     @property
+     def raw_file_names(self) -> List[str]:
+         file_names = []
+         for root, _, files in os.walk(osp.join(self.root, 'raw')):
+             for file in files:
+                 file_names.append(file)
+         return file_names
+
+     @property
+     def processed_file_names(self) -> List[str]:
+         return [
+             'geometric_data_processed.pt', 'pre_filter.pt',
+             'pre_transformed.pt'
+         ]
+
+     @property
+     def token(self) -> Dict[str, Tensor]:
+         if self._token is None:  # Lazy load.
+             self._token = self.tokenize_graph()
+         return self._token
+
+     # Load `is_gold` lazily after initialization:
+     @property
+     def is_gold(self) -> Tensor:
+         if self._is_gold is None:
+             print('lazy load is_gold!!')
+             self._is_gold = self.load_gold_mask()
+         return self._is_gold
+
+     def get_n_id(self, node_idx: IndexType) -> Tensor:
+         if self._n_id is None:
+             assert self._data is not None
+             assert self._data.num_nodes is not None
+             assert isinstance(self._data.num_nodes, int)
+             self._n_id = torch.arange(self._data.num_nodes)
+         return self._n_id[node_idx]
+
+     def load_gold_mask(self) -> Tensor:
+         r"""Use the original train split as the gold split, generating an
+         `is_gold` mask for distinguishing ground-truth labels from pseudo
+         labels.
+         """
+         train_split_idx = self.get_idx_split()['train']
+         assert self._data is not None
+         assert self._data.num_nodes is not None
+         assert isinstance(self._data.num_nodes, int)
+         is_good_tensor = torch.zeros(self._data.num_nodes,
+                                      dtype=torch.bool).view(-1, 1)
+         is_good_tensor[train_split_idx] = True
+         return is_good_tensor
+
+     def get_gold(self, node_idx: IndexType) -> Tensor:
+         r"""Get the gold mask for the given node indices.
+
+         Args:
+             node_idx (torch.Tensor): A tensor containing node indices.
+         """
+         if self._is_gold is None:
+             self._is_gold = self.is_gold
+         return self._is_gold[node_idx]
+
+     def get_idx_split(self) -> Dict[str, Tensor]:
+         return self.split_idx
+
+     def download(self) -> None:
+         print('downloading raw text')
+         raw_text_path = download_google_url(id=self.raw_text_id[self.name],
+                                             folder=f'{self.root}/raw',
+                                             filename='node-text.csv.gz',
+                                             log=True)
+         text_df = read_csv(raw_text_path)
+         self.text = list(text_df['text'])
+
+     def process(self) -> None:
+         if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')):
+             text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz'))
+             self.text = list(text_df['text'])
+         elif self.name in self.raw_text_id:
+             self.download()
+         else:
+             print('The dataset is neither ogbn-products nor ogbn-arxiv, '
+                   'please pass your raw text string list to `text`')
+         if self.text is None:
+             raise ValueError(
+                 "TAGDataset only provides raw text for ogbn-products and "
+                 "ogbn-arxiv by default. The raw text of each node is not "
+                 "specified. Please pass 'text' when converting your dataset "
+                 "to a Text Attributed Graph dataset.")
+
+     def save_node_text(self, text: List[str]) -> None:
+         node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz')
+         if osp.exists(node_text_path):
+             print(f'The raw text already exists at {node_text_path}')
+         else:
+             print(f'Saving raw text file at {node_text_path}')
+             os.makedirs(f'{self.root}/raw', exist_ok=True)
+             text_df = DataFrame(text, columns=['text'])
+             text_df.to_csv(osp.join(node_text_path), compression='gzip',
+                            index=False)
+
+     def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]:
+         r"""Tokenize the text associated with each node, running on CPU.
+
+         Args:
+             batch_size (Optional[int]): The number of text sequences
+                 tokenized per batch.
+         Returns:
+             Dict[str, torch.Tensor]: The tokenized graph.
+         """
+         data_len = 0
+         if self.text is not None:
+             data_len = len(self.text)
+         else:
+             raise ValueError("The TAGDataset needs text for tokenization")
+         token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
+         path = os.path.join(self.processed_dir, 'token', self.tokenizer_name)
+         # Check if the .pt files already exist:
+         token_files_exist = any(
+             os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys)
+
+         if token_files_exist and self.token_on_disk:
+             print('Found tokenized file, loading may take several minutes...')
+             all_encoded_token = {
+                 k: torch.load(os.path.join(path, f'{k}.pt'), weights_only=True)
+                 for k in token_keys
+                 if os.path.exists(os.path.join(path, f'{k}.pt'))
+             }
+             return all_encoded_token
+
+         all_encoded_token = {k: [] for k in token_keys}
+         pbar = tqdm(total=data_len)
+
+         pbar.set_description('Tokenizing Text Attributed Graph')
+         for i in range(0, data_len, batch_size):
+             end_index = min(data_len, i + batch_size)
+             token = self.tokenizer(self.text[i:end_index],
+                                    padding='max_length', truncation=True,
+                                    max_length=512, return_tensors="pt")
+             for k in token.keys():
+                 all_encoded_token[k].append(token[k])
+             pbar.update(end_index - i)
+         pbar.close()
+
+         all_encoded_token = {
+             k: torch.cat(v)
+             for k, v in all_encoded_token.items() if len(v) > 0
+         }
+         if self.token_on_disk:
+             os.makedirs(path, exist_ok=True)
+             print('Saving tokens on Disk')
+             for k, tensor in all_encoded_token.items():
+                 torch.save(tensor, os.path.join(path, f'{k}.pt'))
+                 print('Token saved:', os.path.join(path, f'{k}.pt'))
+         os.environ["TOKENIZERS_PARALLELISM"] = 'true'  # Suppress warning.
+         return all_encoded_token
+
+     def __repr__(self) -> str:
+         return f'{self.__class__.__name__}()'
+
+     class TextDataset(torch.utils.data.Dataset):
+         r"""This nested dataset provides the textual data of each node in
+         the graph. It is created from a parent :class:`TAGDataset` via the
+         :meth:`to_text_dataset` factory method.
+
+         Args:
+             tag_dataset (TAGDataset): The parent dataset.
+         """
+         def __init__(self, tag_dataset: 'TAGDataset') -> None:
+             self.tag_dataset = tag_dataset
+             self.token = tag_dataset.token
+             assert tag_dataset._data is not None
+             self._data = tag_dataset._data
+
+             assert tag_dataset._data.y is not None
+             self.labels = tag_dataset._data.y
+
+         def get_token(self, node_idx: IndexType) -> Dict[str, Tensor]:
+             r"""This function will be called in :meth:`__getitem__`.
+
+             Args:
+                 node_idx (IndexType): The selected node indices in each batch.
+             Returns:
+                 items (Dict[str, Tensor]): The input for the LM.
+             """
+             items = {k: v[node_idx] for k, v in self.token.items()}
+             return items
+
+         # For LM training:
+         def __getitem__(
+             self, node_id: IndexType
+         ) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]:
+             r"""Overrides :meth:`torch.utils.data.Dataset.__getitem__` and is
+             called when iterating over batches in the dataloader. Make sure
+             all of the following key-value pairs are present in the returned
+             dictionary.
+
+             Args:
+                 node_id (List[int]): A list of node indices for selecting
+                     tokens, labels, etc., when iterating the data loader
+                     for the LM.
+             Returns:
+                 items (dict): The input key-value pairs for language model
+                     training and inference.
+             """
+             item: Dict[str, Union[Tensor, Dict[str, Tensor]]] = {}
+             item['input'] = self.get_token(node_id)
+             item['labels'] = self.labels[node_id]
+             item['is_gold'] = self.tag_dataset.get_gold(node_id)
+             item['n_id'] = self.tag_dataset.get_n_id(node_id)
+             return item
+
+         def __len__(self) -> int:
+             assert self._data.num_nodes is not None
+             return self._data.num_nodes
+
+         def get(self, idx: int) -> BaseData:
+             return self._data
+
+         def __repr__(self) -> str:
+             return f'{self.__class__.__name__}()'
+
+     def to_text_dataset(self) -> TextDataset:
+         r"""Factory method that builds a text dataset from a Text Attributed
+         Graph dataset, where each data point is a node's associated text
+         tokens.
+         """
+         return TAGDataset.TextDataset(self)
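The following is a minimal usage sketch for the new TAGDataset (not part of the package diff). It assumes the class is exported from torch_geometric.datasets, that ogb is installed for PygNodePropPredDataset, and that the tokenizer name is a valid Hugging Face model id; all paths and hyperparameters are illustrative.

    from ogb.nodeproppred import PygNodePropPredDataset
    from torch.utils.data import DataLoader
    from torch_geometric.datasets import TAGDataset

    # Wrap an OGB dataset into a Text Attributed Graph; raw node text for
    # ogbn-arxiv and ogbn-products is downloaded automatically.
    ogb_dataset = PygNodePropPredDataset('ogbn-arxiv', root='./data')
    tag_dataset = TAGDataset(root='./data', dataset=ogb_dataset,
                             tokenizer_name='prajjwal1/bert-tiny',  # assumed model id
                             token_on_disk=True)

    # Text view for LM training: each item contains 'input', 'labels',
    # 'is_gold' and 'n_id', matching TextDataset.__getitem__ above.
    text_dataset = tag_dataset.to_text_dataset()
    lm_loader = DataLoader(text_dataset, batch_size=32, shuffle=True)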
@@ -1,3 +1,7 @@
  from .performer import PerformerAttention
+ from .qformer import QFormer

- __all__ = ['PerformerAttention']
+ __all__ = [
+     'PerformerAttention',
+     'QFormer',
+ ]
@@ -0,0 +1,71 @@
+ from typing import Callable
+
+ import torch
+
+
+ class QFormer(torch.nn.Module):
+     r"""The Querying Transformer (Q-Former) from
+     `"BLIP-2: Bootstrapping Language-Image Pre-training
+     with Frozen Image Encoders and Large Language Models"
+     <https://arxiv.org/pdf/2301.12597>`_ paper.
+
+     Args:
+         input_dim (int): The number of features in the input.
+         hidden_dim (int): The hidden dimension of the feed-forward network
+             in each encoder layer.
+         output_dim (int): The final output dimension.
+         num_heads (int): The number of attention heads.
+         num_layers (int): The number of sub-encoder layers in the encoder.
+         dropout (float): The dropout value in each encoder layer.
+             (default: :obj:`0.0`)
+
+     .. note::
+         This is a simplified version of the original Q-Former implementation.
+     """
+     def __init__(
+         self,
+         input_dim: int,
+         hidden_dim: int,
+         output_dim: int,
+         num_heads: int,
+         num_layers: int,
+         dropout: float = 0.0,
+         activation: Callable = torch.nn.ReLU(),
+     ) -> None:
+
+         super().__init__()
+         self.num_layers = num_layers
+         self.num_heads = num_heads
+
+         self.layer_norm = torch.nn.LayerNorm(input_dim)
+         self.encoder_layer = torch.nn.TransformerEncoderLayer(
+             d_model=input_dim,
+             nhead=num_heads,
+             dim_feedforward=hidden_dim,
+             dropout=dropout,
+             activation=activation,
+             batch_first=True,
+         )
+         self.encoder = torch.nn.TransformerEncoder(
+             self.encoder_layer,
+             num_layers=num_layers,
+         )
+         self.project = torch.nn.Linear(input_dim, output_dim)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         r"""Forward pass.
+
+         Args:
+             x (torch.Tensor): Input sequence to the encoder layer.
+                 :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
+                 batch-size :math:`B`, sequence length :math:`N`,
+                 and feature dimension :math:`F`.
+         """
+         x = self.layer_norm(x)
+         x = self.encoder(x)
+         out = self.project(x)
+         return out
+
+     def __repr__(self) -> str:
+         return (f'{self.__class__.__name__}('
+                 f'num_heads={self.num_heads}, '
+                 f'num_layers={self.num_layers})')
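A brief usage sketch of the QFormer module added above; tensor sizes are illustrative, and input_dim must be divisible by num_heads:

    import torch
    from torch_geometric.nn.attention import QFormer

    model = QFormer(input_dim=16, hidden_dim=32, output_dim=8, num_heads=4,
                    num_layers=2, dropout=0.1)

    x = torch.randn(3, 10, 16)  # (batch_size, sequence_length, input_dim)
    out = model(x)              # shape: [3, 10, 8]
    print(model)                # QFormer(num_heads=4, num_layers=2)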
@@ -29,7 +29,8 @@ from .pmlp import PMLP
  from .neural_fingerprint import NeuralFingerprint
  from .visnet import ViSNet
  from .g_retriever import GRetriever
-
+ from .molecule_gpt import MoleculeGPT
+ from .glem import GLEM
  # Deprecated:
  from torch_geometric.explain.algorithm.captum import (to_captum_input,
                                                         captum_output_to_dicts)
@@ -77,4 +78,6 @@ __all__ = classes = [
      'NeuralFingerprint',
      'ViSNet',
      'GRetriever',
+     'MoleculeGPT',
+     'GLEM',
  ]
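With the export changes above, the new models can be imported from the top-level models package. Their constructor arguments are not shown in this diff, so only the imports are sketched here:

    from torch_geometric.nn.models import GLEM, MoleculeGPT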