pyg-nightly 2.7.0.dev20241118__py3-none-any.whl → 2.7.0.dev20241120__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
- {pyg_nightly-2.7.0.dev20241118.dist-info → pyg_nightly-2.7.0.dev20241120.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20241118.dist-info → pyg_nightly-2.7.0.dev20241120.dist-info}/RECORD +15 -10
- torch_geometric/__init__.py +1 -1
- torch_geometric/datasets/__init__.py +4 -0
- torch_geometric/datasets/molecule_gpt_dataset.py +480 -0
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/models/__init__.py +4 -1
- torch_geometric/nn/models/glem.py +384 -0
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/nlp/llm.py +1 -1
- torch_geometric/nn/nlp/sentence_transformer.py +3 -0
- {pyg_nightly-2.7.0.dev20241118.dist-info → pyg_nightly-2.7.0.dev20241120.dist-info}/WHEEL +0 -0
torch_geometric/nn/models/glem.py
ADDED
@@ -0,0 +1,384 @@
+from typing import List, Optional, Union
+
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+
+from torch_geometric.loader import DataLoader, NeighborLoader
+from torch_geometric.nn.models import GraphSAGE, basic_gnn
+
+
+class GLEM(torch.nn.Module):
+    r"""This GNN+LM co-training model is based on GLEM from the `"Learning on
+    Large-scale Text-attributed Graphs via Variational Inference"
+    <https://arxiv.org/abs/2210.14709>`_ paper.
+
+    Args:
+        lm_to_use (str): A TextEncoder from huggingface model repo
+            with a classifier(default: TinyBERT)
+        gnn_to_use (torch_geometric.nn.models): (default: GraphSAGE)
+        out_channels (int): output channels for LM and GNN, should be same
+        num_gnn_heads Optional[int]: Number of heads for attention, if needed
+        num_gnn_layers (int): number of gnn layers
+        gnn_loss: loss function for gnn, (default: CrossEntropyLoss)
+        lm_loss: loss function for Language Model, (default: CrossEntropyLoss)
+        alpha (float): pseudo label weight of E-step, LM optimization,
+            (default: 0.5)
+        beta (float): pseudo label weight of M-step, GNN optimization,
+            (default: 0.5)
+        lm_dtype (torch.dtype): the data type once you load LM into memory,
+            (default: torch.bfloat16)
+        lm_use_lora (bool): choose if LM use Lora peft for fine tune,
+            (default: True)
+        lora_target_modules: The names of the target modules to apply the lora
+            adapter to, e.g. ['q_proj', 'v_proj'] for LLM , (default: None)
+
+    .. note::
+        See `examples/llm_plus_gnn/glem.py` for example usage.
+    """
+    def __init__(
+        self,
+        lm_to_use: str = 'prajjwal1/bert-tiny',
+        gnn_to_use: basic_gnn = GraphSAGE,
+        out_channels: int = 47,
+        gnn_loss=nn.CrossEntropyLoss(reduction='mean'),
+        lm_loss=nn.CrossEntropyLoss(reduction='mean'),
+        alpha: float = 0.5,
+        beta: float = 0.5,
+        lm_dtype: torch.dtype = torch.bfloat16,
+        lm_use_lora: bool = True,
+        lora_target_modules: Optional[Union[List[str], str]] = None,
+        device: Union[str, torch.device] = torch.device('cpu'),
+    ):
+        super().__init__()
+        self.device = device
+        self.lm_loss = lm_loss
+        self.gnn = gnn_to_use
+        self.gnn_loss = gnn_loss
+        self.alpha = alpha
+        self.beta = beta
+        self.gnn_loss = gnn_loss
+        self.lm = lm_to_use
+        from transformers import AutoModelForSequenceClassification
+        self.lm = AutoModelForSequenceClassification.from_pretrained(
+            lm_to_use, num_labels=out_channels, torch_dtype=lm_dtype,
+            offload_folder="offload", trust_remote_code=True)
+        if lm_use_lora:
+            from peft import (
+                LoraConfig,
+                TaskType,
+                get_peft_model,
+                prepare_model_for_kbit_training,
+            )
+            print("Training LM with LORA!")
+            self.lm = prepare_model_for_kbit_training(self.lm)
+            config = LoraConfig(task_type=TaskType.SEQ_CLS, r=16,
+                                lora_alpha=16, lora_dropout=0.05, bias="none",
+                                target_modules=lora_target_modules)
+            self.lm = get_peft_model(self.lm, config)
+            self.lm.print_trainable_parameters()
+        self.lm.config.pad_token_id = self.lm.config.eos_token_id
+        self.lm_device = self.lm.device
+
+        if self.lm.num_labels != self.gnn.out_channels:
+            raise ValueError('''The output channel of language model \
+                             and gnn should be the same''')
+
+    def pre_train_gnn(self, train_loader: NeighborLoader,
+                      optimizer: torch.optim.Optimizer, num_epochs: int,
+                      patience: int, ext_pseudo_labels: torch.Tensor = None,
+                      is_augmented: bool = False, verbose: bool = True):
+        # Pretrain GNN, optional steps if you do not have pseudo labels.
+        best_acc = 0
+        early_stopping = 0
+        # training only based on gold data
+        for epoch in range(0, num_epochs):
+            acc, loss = self.train_gnn(train_loader, optimizer, epoch,
+                                       ext_pseudo_labels, is_augmented,
+                                       verbose)
+            if acc < best_acc:
+                early_stopping += 1
+                if early_stopping > patience:
+                    print(f'Early stopped by Epoch: {epoch}, '
+                          f'Best acc: {best_acc}')
+                    break
+            best_acc = max(best_acc, acc)
+
+    def pre_train_lm(self, train_loader: DataLoader,
+                     optimizer: torch.optim.Optimizer, num_epochs: int,
+                     patience: int, ext_pseudo_labels: torch.Tensor = None,
+                     is_augmented: bool = False, verbose: bool = True):
+        # Pretrain language model
+        best_acc = 0
+        early_stopping = 0
+        for epoch in range(1, num_epochs + 1):
+            acc, loss = self.train_lm(train_loader, optimizer, epoch,
+                                      ext_pseudo_labels, is_augmented, verbose)
+            if acc < best_acc:
+                early_stopping += 1
+                if early_stopping > patience:
+                    print(f'Early stopped by Epoch: {epoch}, '
+                          f'Best acc: {best_acc}')
+                    break
+            best_acc = max(best_acc, acc)
+
+    def train(self, em_phase: str, train_loader: Union[DataLoader,
+                                                       NeighborLoader],
+              optimizer: torch.optim.Optimizer, pseudo_labels: torch.Tensor,
+              epoch: int, is_augmented: bool = False, verbose: bool = False):
+        r"""GLEM training step, EM steps.
+
+        Args:
+            em_phase(str): 'gnn' or 'lm' choose which phase you are training on
+            train_loader(Union[DataLoader, NeighborLoader]): use DataLoader for
+                lm training, include tokenized data, labels is_gold mask.
+                use NeighborLoader for gnn training, include x, edge_index.
+            optimizer (torch.optim.Optimizer): optimizer for training
+            pseudo_labels(torch.Tensor): the predicted labels used as pseudo
+                labels
+            epoch (int): current epoch
+            is_augmented (bool): will use pseudo_labels or not
+            verbose (bool): print training progress bar or not
+
+        Returns:
+            acc (float): training accuracy
+            loss (float): loss value
+        """
+        pseudo_labels = pseudo_labels.to(self.device)
+        if em_phase == 'gnn':
+            acc, loss = self.train_gnn(train_loader, optimizer, epoch,
+                                       pseudo_labels, is_augmented, verbose)
+        if em_phase == 'lm':
+            acc, loss = self.train_lm(train_loader, optimizer, epoch,
+                                      pseudo_labels, is_augmented, verbose)
+        return acc, loss
+
+    def train_lm(self, train_loader: DataLoader,
+                 optimizer: torch.optim.Optimizer, epoch: int,
+                 pseudo_labels: torch.Tensor = None,
+                 is_augmented: bool = False, verbose: bool = True):
+        r"""Language model Training in every epoch.
+
+        Args:
+            train_loader (loader.dataloader.DataLoader): text token dataloader
+            optimizer (torch.optim.Optimizer): model optimizer
+            epoch (int): current train epoch
+            pseudo_labels (torch.Tensor): 1-D tensor, predictions from gnn
+            is_augmented (bool): train with pseudo labels or not
+            verbose (bool): print training progress bar or not
+
+        Returns:
+            approx_acc (torch.tensor): training accuracy
+            loss (torch.float): loss value
+
+        """
+        all_out = []
+        total_loss = total_correct = 0
+        num_nodes = train_loader.dataset.indices.size(0)
+        self.lm.train()
+        if verbose:
+            pbar = tqdm(total=num_nodes)
+            pbar.set_description(f'Epoch {epoch:02d}')
+        for batch in train_loader:
+            inputs = {k: v.to(self.device) for k, v in batch['input'].items()}
+            out = self.lm(**inputs).logits
+            labels = batch['labels'].to(self.device).squeeze()
+            # training with pseudo labels or not
+            if is_augmented:
+                pl_batch = pseudo_labels[batch['n_id']].to(self.device)
+            else:
+                pl_batch = None
+            loss = self.loss(out, labels, self.lm_loss,
+                             batch['is_gold'].to(self.device), pl_batch,
+                             self.alpha, is_augmented)
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+            all_out.append(out)
+            total_correct += int(out.argmax(dim=-1).eq(labels).sum())
+            total_loss += float(loss)
+            if verbose:
+                pbar.update(batch['n_id'].size(0))
+
+        all_out = torch.cat(all_out, dim=0)
+        approx_acc = total_correct / num_nodes
+        loss = total_loss / len(train_loader)
+        if verbose:
+            pbar.close()
+            print(f'Epoch {epoch:02d} Loss: {loss:.4f} '
+                  f'Approx. Train: {approx_acc:.4f}')
+        return approx_acc, loss
+
+    def train_gnn(self, train_loader: NeighborLoader,
+                  optimizer: torch.optim.Optimizer, epoch: int,
+                  pseudo_labels: torch.Tensor = None,
+                  is_augmented: bool = False, verbose: bool = True):
+        r"""GNN training step in every epoch.
+
+        Args:
+            train_loader (loader.NeighborLoader): gnn Neighbor node loader
+            optimizer (torch.optim.Optimizer): model optimizer
+            epoch (int): current train epoch
+            pseudo_labels(torch.tensor): 1-D tensor, predictions from lm
+            is_augmented(bool): use pseudo labeled node or not
+            verbose (bool): print training progress or not
+
+        Returns:
+            approx_acc (torch.tensor): training accuracy
+            loss (torch.float): loss value
+        """
+        self.gnn.train()
+        num_nodes = train_loader.input_nodes.size(0)
+        if verbose:
+            pbar = tqdm(total=num_nodes)
+            pbar.set_description(f'Epoch {epoch:02d}')
+        total_loss = total_correct = 0
+        all_out = []
+        for batch in train_loader:
+            batch = batch.to(self.device)
+            out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size]
+            all_out.append(out)
+            labels = batch.y[:batch.batch_size].squeeze()
+            is_gold_batch = batch.is_gold[:batch.batch_size].squeeze()
+            # training with pseudo labels or not
+            if is_augmented and pseudo_labels is not None:
+                pl_batch = pseudo_labels[batch.n_id[:batch.batch_size]]
+            else:
+                pl_batch = None
+            loss = self.loss(out, labels, self.gnn_loss, is_gold_batch,
+                             pl_batch, self.beta, is_augmented)
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+            total_loss += float(loss)
+            total_correct += int(out.argmax(dim=-1).eq(labels).sum())
+            if verbose:
+                pbar.update(batch.batch_size)
+
+        all_out = torch.cat(all_out, dim=0)
+        loss = total_loss / len(train_loader)
+        approx_acc = total_correct / num_nodes
+        if verbose:
+            pbar.close()
+            print(f'Epoch: {epoch:02d} Loss: {loss:.4f} '
+                  f'Approx. Train: {approx_acc:.4f}')
+        return approx_acc, loss
+
+    @torch.no_grad()
+    def inference(self, em_phase: str, data_loader: Union[NeighborLoader,
+                                                          DataLoader],
+                  verbose: bool = False):
+        r"""GLEM inference step.
+
+        Args:
+            em_phase(str): 'gnn' or 'lm'
+            data_loader(dataloader or Neighborloader):
+                dataloader: for lm training, include tokenized data
+                nodeloader: for gnn training, include x, edge_index
+            verbose(bool): print inference progress or not
+
+        Returns:
+            out (torch.Tensor): n * m tensor, m is number of classes,
+                n is number of nodes
+        """
+        out = None
+        if em_phase == 'gnn':
+            self.gnn.eval()
+            out = self.inference_gnn(data_loader, verbose)
+        elif em_phase == 'lm':
+            self.lm.eval()
+            out = self.inference_lm(data_loader, verbose)
+        return out
+
+    @torch.no_grad()
+    def inference_lm(self, data_loader: DataLoader, verbose: bool = True):
+        r"""LM inference step.
+
+        Args:
+            data_loader (Dataloader): include token, labels, and gold mask
+            verbose (bool): print progress bar or not
+
+        Returns:
+            preds (tensor): prediction from GNN, convert to pseudo labels
+                by preds.argmax(dim=-1).unsqueeze(1)
+        """
+        if verbose:
+            pbar = tqdm(total=data_loader.dataset._data.num_nodes)
+            pbar.set_description('LM inference stage')
+        self.lm.eval()
+        preds = []
+        for batch in data_loader:
+            inputs = {k: v.to(self.device) for k, v in batch['input'].items()}
+            logits = self.lm(**inputs).logits
+            preds.append(logits)
+            if verbose:
+                pbar.update(batch['n_id'].size(0))
+        if verbose:
+            pbar.close()
+        preds = torch.cat(preds)
+        return preds
+
+    @torch.no_grad()
+    def inference_gnn(self, data_loader: NeighborLoader, verbose: bool = True):
+        r"""GNN inference step.
+
+        Args:
+            data_loader(NeighborLoader): include x, edge_index,
+            verbose (bool): print progress bar or not
+
+        Returns:
+            preds (tensor): prediction from GNN,
+                convert to pseudo labels by preds.argmax(dim=-1).unsqueeze(1)
+        """
+        if verbose:
+            pbar = tqdm(total=data_loader.data.num_nodes)
+            pbar.set_description('GNN inference stage')
+        preds = []
+        self.gnn.eval()
+        for batch in data_loader:
+            batch = batch.to(self.device)
+            out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size]
+            preds.append(out)
+            if verbose:
+                pbar.update(batch.batch_size)
+        if verbose:
+            pbar.close()
+        preds = torch.cat(preds, dim=0)
+        return preds
+
+    def loss(self, logits: torch.Tensor, labels: torch.Tensor,
+             loss_func: torch.nn.functional, is_gold: torch.Tensor,
+             pseudo_labels: torch.Tensor = None, pl_weight: float = 0.5,
+             is_augmented: bool = True):
+        r"""Core function of variational EM inference, this function is aming
+        on combining loss value on gold(original train) and loss value on
+        pseudo labels.
+
+        Reference:
+        <https://github.com/AndyJZhao/GLEM/blob/main/src/models/GLEM/GLEM_utils.py> # noqa
+
+        Args:
+            logits(torch.tensor): predict results from LM or GNN
+            labels(torch.tensor): combined node labels from ground truth and
+                pseudo labels(if provided)
+            loss_func(torch.nn.modules.loss): loss function for classification
+            is_gold(tensor): a tensor with bool value that mask ground truth
+                label and during training, thus ~is_gold mask pseudo labels
+            pseudo_labels(torch.tensor): predictions from other model
+            pl_weight: the pseudo labels used in E-step and M-step optimization
+                alpha in E-step, beta in M-step respectively
+            is_augmented: use EM or just train GNN and LM with gold data
+
+        """
+        def deal_nan(x):
+            return 0 if torch.isnan(x) else x
+
+        if is_augmented and (sum(~is_gold) > 0):
+            mle_loss = deal_nan(loss_func(logits[is_gold], labels[is_gold]))
+            # all other labels beside from ground truth(gold labels)
+            pseudo_label_loss = deal_nan(
+                loss_func(logits[~is_gold], pseudo_labels[~is_gold]))
+            loss = pl_weight * pseudo_label_loss + (1 - pl_weight) * mle_loss
+        else:
+            loss = loss_func(logits, labels)
+        return loss
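The new GLEM module alternates an E-step (fine-tuning the LM on gold labels plus GNN pseudo labels) and an M-step (training the GNN on gold labels plus LM pseudo labels). Below is a minimal, hedged usage sketch of the API shown above, not the bundled example script: `lm_loader` (a DataLoader over tokenized node texts exposing `input`, `labels`, `is_gold` and `n_id` fields) and `gnn_loader` (a NeighborLoader over the text-attributed graph) are assumed to be prepared by the caller, the GraphSAGE dimensions and epoch counts are placeholders, and `GLEM` is assumed to be re-exported from `torch_geometric.nn.models` as the updated `__init__.py` in this release suggests. See `examples/llm_plus_gnn/glem.py` for the full pipeline.

import torch

from torch_geometric.nn.models import GLEM, GraphSAGE

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Placeholder feature/class sizes; use your dataset's values instead:
gnn = GraphSAGE(in_channels=128, hidden_channels=256, num_layers=3,
                out_channels=47).to(device)
model = GLEM(lm_to_use='prajjwal1/bert-tiny', gnn_to_use=gnn,
             out_channels=47, device=device)
lm_opt = torch.optim.Adam(model.lm.parameters(), lr=1e-4)
gnn_opt = torch.optim.Adam(model.gnn.parameters(), lr=1e-3)

# `lm_loader` / `gnn_loader` are assumed to exist (see note above).
model.pre_train_lm(lm_loader, lm_opt, num_epochs=2, patience=1)
model.pre_train_gnn(gnn_loader, gnn_opt, num_epochs=2, patience=1)

for em_round in range(3):
    # E-step: train the LM with pseudo labels predicted by the GNN.
    gnn_preds = model.inference('gnn', gnn_loader).argmax(dim=-1)
    model.train('lm', lm_loader, lm_opt, gnn_preds, em_round,
                is_augmented=True)
    # M-step: train the GNN with pseudo labels predicted by the LM.
    lm_preds = model.inference('lm', lm_loader).argmax(dim=-1)
    model.train('gnn', gnn_loader, gnn_opt, lm_preds, em_round,
                is_augmented=True)

# Note: in practice the per-loader predictions must be re-indexed to global
# node IDs before being used as pseudo labels, as done in the example script.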
torch_geometric/nn/models/molecule_gpt.py
ADDED
@@ -0,0 +1,222 @@
+from typing import List, Optional
+
+import torch
+from torch import Tensor
+
+from torch_geometric.nn.attention import QFormer
+from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS
+from torch_geometric.utils import to_dense_batch
+
+
+def pad_or_truncate(embeddings: Tensor, max_seq_len: int,
+                    padding_value: int = 0) -> Tensor:
+    batch_size, current_seq_len, d = embeddings.size()
+
+    if current_seq_len > max_seq_len:
+        return embeddings[:, :max_seq_len, :]
+    elif current_seq_len < max_seq_len:
+        pad_tensor = torch.full((batch_size, max_seq_len - current_seq_len, d),
+                                padding_value, dtype=embeddings.dtype,
+                                device=embeddings.device)
+        return torch.cat([embeddings, pad_tensor], dim=1)
+    else:
+        return embeddings
+
+
+class MoleculeGPT(torch.nn.Module):
+    r"""The MoleculeGPT model from the `"MoleculeGPT: Instruction
+    Following Large Language Models for Molecular Property Prediction"
+    <https://ai4d3.github.io/papers/34.pdf>`_ paper.
+
+    Args:
+        llm (LLM): The LLM to use.
+        graph_encoder (torch.nn.Module): Encode 2D molecule graph.
+        smiles_encoder (torch.nn.Module): Encode 1D SMILES.
+        mlp_out_channels (int, optional): The size of each embedding
+            after qformer encoding. (default: :obj:`32`)
+        max_tokens (int, optional): Max output tokens of 1D/2D encoder.
+            (default: :obj:`20`)
+
+    .. warning::
+        This module has been tested with the following HuggingFace models
+
+        * :obj:`llm_to_use="lmsys/vicuna-7b-v1.5"`
+
+        and may not work with other models. See other models at `HuggingFace
+        Models <https://huggingface.co/models>`_ and let us know if you
+        encounter any issues.
+
+    .. note::
+        For an example of using :class:`MoleculeGPT`, see
+        `examples/llm/molecule_gpt.py <https://github.com/pyg-team/
+        pytorch_geometric/blob/master/examples/llm/molecule_gpt.py>`_.
+    """
+    def __init__(
+        self,
+        llm: LLM,
+        graph_encoder: torch.nn.Module,
+        smiles_encoder: torch.nn.Module,
+        mlp_out_channels: int = 32,
+        max_tokens: Optional[int] = 20,
+    ) -> None:
+        super().__init__()
+        self.llm = llm
+        self.graph_encoder = graph_encoder.to(self.llm.device)
+        self.smiles_encoder = smiles_encoder.to(self.llm.device)
+
+        self.graph_qformer = QFormer(
+            input_dim=self.graph_encoder.nn[-1].out_features,
+            hidden_dim=mlp_out_channels,
+            output_dim=mlp_out_channels,
+            num_heads=4,
+            num_layers=2,
+        ).to(self.llm.device)
+
+        self.smiles_qformer = QFormer(
+            input_dim=self.smiles_encoder.model.pooler.dense.out_features,
+            hidden_dim=mlp_out_channels,
+            output_dim=mlp_out_channels,
+            num_heads=4,
+            num_layers=2,
+        ).to(self.llm.device)
+
+        self.max_tokens = max_tokens
+
+        self.word_embedding = self.llm.word_embedding
+        self.llm_generator = self.llm.llm
+
+        # LLMs
+        in_dim = 2 * mlp_out_channels * max_tokens
+        out_dim = self.llm.llm.model.embed_tokens.embedding_dim
+        self.projector = torch.nn.Sequential(
+            torch.nn.Linear(in_dim, in_dim),
+            torch.nn.Sigmoid(),
+            torch.nn.Linear(in_dim, out_dim),
+        ).to(self.llm.device)
+
+    def encode(
+        self,
+        x: Tensor,
+        edge_index: Tensor,
+        batch: Tensor,
+        edge_attr: Optional[Tensor],
+        smiles: List[str],
+    ) -> Tensor:
+        batch_size = len(smiles)
+        # 2D Graph Branch: [bs, node_len, d]
+        x = x.to(self.llm.device)
+        edge_index = edge_index.to(self.llm.device)
+        if edge_attr is not None:
+            edge_attr = edge_attr.to(self.llm.device)
+        batch = batch.to(self.llm.device)
+
+        x_graph = self.graph_encoder(x, edge_index, edge_attr=edge_attr)
+        x_graph = to_dense_batch(x_graph, batch)[0]
+        out_graph = self.graph_qformer(x_graph)
+        out_graph = pad_or_truncate(out_graph, max_seq_len=self.max_tokens,
+                                    padding_value=0)
+        out_graph = out_graph.view(batch_size, -1)
+
+        # 1D SMILES Branch: [bs, seq_len, d]
+        x_smiles = self.smiles_encoder.encode(smiles,
+                                              output_device=self.llm.device)
+        out_smiles = self.smiles_qformer(x_smiles)
+        out_smiles = pad_or_truncate(out_smiles, max_seq_len=self.max_tokens,
+                                     padding_value=0)
+        out_smiles = out_smiles.view(batch_size, -1)
+
+        # Merge into LLMs
+        x_cat = torch.cat([out_graph, out_smiles], dim=1)
+        return x_cat
+
+    def forward(
+        self,
+        x: Tensor,
+        edge_index: Tensor,
+        batch: Tensor,
+        edge_attr: Optional[Tensor],
+        smiles: List[str],
+        instructions: List[str],
+        label: List[str],
+        additional_text_context: Optional[List[str]] = None,
+    ):
+        x = self.encode(x, edge_index, batch, edge_attr, smiles)
+        x = self.projector(x)
+        xs = x.split(1, dim=0)
+
+        batch_unique = batch.unique()
+        batch_size = len(instructions)
+        if len(batch_unique) < batch_size:
+            xs = [
+                xs[i] if i in batch_unique else None for i in range(batch_size)
+            ]
+
+        (
+            inputs_embeds,
+            attention_mask,
+            label_input_ids,
+        ) = self.llm._get_embeds(instructions, additional_text_context, xs,
+                                 label)
+
+        with self.llm.autocast_context:
+            outputs = self.llm_generator(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                return_dict=True,
+                labels=label_input_ids,
+            )
+
+        return outputs.loss
+
+    @torch.no_grad()
+    def inference(
+        self,
+        x: Tensor,
+        edge_index: Tensor,
+        batch: Tensor,
+        edge_attr: Optional[Tensor],
+        smiles: List[str],
+        instructions: List[str],
+        additional_text_context: Optional[List[str]] = None,
+        max_out_tokens: Optional[int] = MAX_NEW_TOKENS,
+    ):
+        x = self.encode(x, edge_index, batch, edge_attr, smiles)
+        x = self.projector(x)
+        xs = x.split(1, dim=0)
+
+        # Handle questions without node features:
+        batch_unique = batch.unique()
+        batch_size = len(instructions)
+        if len(batch_unique) < batch_size:
+            xs = [
+                xs[i] if i in batch_unique else None for i in range(batch_size)
+            ]
+
+        inputs_embeds, attention_mask, _ = self.llm._get_embeds(
+            instructions, additional_text_context, xs)
+
+        bos_token = self.llm.tokenizer(
+            BOS,
+            add_special_tokens=False,
+        ).input_ids[0]
+
+        with self.llm.autocast_context:
+            outputs = self.llm_generator.generate(
+                inputs_embeds=inputs_embeds,
+                max_new_tokens=max_out_tokens,
+                attention_mask=attention_mask,
+                bos_token_id=bos_token,
+                use_cache=True  # Important to set!
+            )
+
+        return self.llm.tokenizer.batch_decode(
+            outputs,
+            skip_special_tokens=True,
+        )
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}(\n'
+                f'  llm={self.llm},\n'
+                f'  graph={self.graph_encoder.__class__.__name__},\n'
+                f'  smiles={self.smiles_encoder},\n'
+                f')')
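The `pad_or_truncate` helper above is what lets MoleculeGPT turn the variable-length QFormer outputs of the graph and SMILES branches into a fixed-size projector input of `2 * mlp_out_channels * max_tokens`: each branch is padded or cut to `max_tokens` positions and then flattened per sample. A small self-contained sanity check of that shape behaviour (dummy tensors only, no model weights needed):

import torch

from torch_geometric.nn.models.molecule_gpt import pad_or_truncate

emb = torch.randn(2, 7, 32)                      # [batch, seq_len, dim]
padded = pad_or_truncate(emb, max_seq_len=20)    # zero-padded to 20 positions
truncated = pad_or_truncate(emb, max_seq_len=5)  # cut down to 5 positions

assert padded.shape == (2, 20, 32)
assert truncated.shape == (2, 5, 32)
assert padded[:, 7:].abs().sum() == 0            # the padding is all zeros

# Flattening one padded branch gives mlp_out_channels * max_tokens features
# per sample (32 * 20 = 640); two branches concatenated -> 1280, which is the
# `in_dim` of the projector above.
print(padded.view(2, -1).size(-1))               # 640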
torch_geometric/nn/nlp/llm.py
CHANGED
@@ -56,7 +56,7 @@ class LLM(torch.nn.Module):
         allocate the correct number of GPUs needed, given the available GPU
         memory of your GPUs.
     dtype (torch.dtype, optional): The data type to use for the LLM.
-        (default :obj: `torch.
+        (default :obj: `torch.bfloat16`)
     """
     def __init__(
         self,
torch_geometric/nn/nlp/sentence_transformer.py
CHANGED
@@ -10,6 +10,7 @@ class PoolingStrategy(Enum):
     MEAN = 'mean'
     LAST = 'last'
     CLS = 'cls'
+    LAST_HIDDEN_STATE = 'last_hidden_state'


 class SentenceTransformer(torch.nn.Module):
@@ -38,6 +39,8 @@ class SentenceTransformer(torch.nn.Module):
             emb = mean_pooling(emb, attention_mask)
         elif self.pooling_strategy == PoolingStrategy.LAST:
             emb = last_pooling(emb, attention_mask)
+        elif self.pooling_strategy == PoolingStrategy.LAST_HIDDEN_STATE:
+            emb = out.last_hidden_state
         else:
             assert self.pooling_strategy == PoolingStrategy.CLS
             emb = emb[:, 0, :]