pyg-nightly 2.7.0.dev20241119__py3-none-any.whl → 2.7.0.dev20241121__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,384 @@
+ from typing import List, Optional, Union
+
+ import torch
+ import torch.nn as nn
+ from tqdm import tqdm
+
+ from torch_geometric.loader import DataLoader, NeighborLoader
+ from torch_geometric.nn.models import GraphSAGE, basic_gnn
+
+
+ class GLEM(torch.nn.Module):
+     r"""This GNN+LM co-training model is based on GLEM from the `"Learning on
+     Large-scale Text-attributed Graphs via Variational Inference"
+     <https://arxiv.org/abs/2210.14709>`_ paper.
+
+     Args:
+         lm_to_use (str): A text encoder from the HuggingFace model hub with a
+             classification head. (default: TinyBERT)
+         gnn_to_use (torch_geometric.nn.models): The GNN model to co-train.
+             (default: GraphSAGE)
+         out_channels (int): Number of output channels of both the LM and the
+             GNN; they must match.
+         gnn_loss: Loss function for the GNN. (default: CrossEntropyLoss)
+         lm_loss: Loss function for the language model.
+             (default: CrossEntropyLoss)
+         alpha (float): Pseudo-label weight of the E-step (LM optimization).
+             (default: 0.5)
+         beta (float): Pseudo-label weight of the M-step (GNN optimization).
+             (default: 0.5)
+         lm_dtype (torch.dtype): Data type used to load the LM into memory.
+             (default: torch.bfloat16)
+         lm_use_lora (bool): Whether to fine-tune the LM with LoRA (PEFT).
+             (default: True)
+         lora_target_modules: The names of the target modules to apply the
+             LoRA adapter to, e.g. ['q_proj', 'v_proj'] for an LLM.
+             (default: None)
+         device (str or torch.device): The device to run on.
+             (default: torch.device('cpu'))
+
+     .. note::
+         See `examples/llm_plus_gnn/glem.py` for example usage.
+     """
+     def __init__(
+         self,
+         lm_to_use: str = 'prajjwal1/bert-tiny',
+         gnn_to_use: basic_gnn = GraphSAGE,
+         out_channels: int = 47,
+         gnn_loss=nn.CrossEntropyLoss(reduction='mean'),
+         lm_loss=nn.CrossEntropyLoss(reduction='mean'),
+         alpha: float = 0.5,
+         beta: float = 0.5,
+         lm_dtype: torch.dtype = torch.bfloat16,
+         lm_use_lora: bool = True,
+         lora_target_modules: Optional[Union[List[str], str]] = None,
+         device: Union[str, torch.device] = torch.device('cpu'),
+     ):
+         super().__init__()
+         self.device = device
+         self.lm_loss = lm_loss
+         self.gnn = gnn_to_use
+         self.gnn_loss = gnn_loss
+         self.alpha = alpha
+         self.beta = beta
+         self.lm = lm_to_use
+         from transformers import AutoModelForSequenceClassification
+         self.lm = AutoModelForSequenceClassification.from_pretrained(
+             lm_to_use, num_labels=out_channels, torch_dtype=lm_dtype,
+             offload_folder="offload", trust_remote_code=True)
+         if lm_use_lora:
+             from peft import (
+                 LoraConfig,
+                 TaskType,
+                 get_peft_model,
+                 prepare_model_for_kbit_training,
+             )
+             print("Training LM with LoRA!")
+             self.lm = prepare_model_for_kbit_training(self.lm)
+             config = LoraConfig(task_type=TaskType.SEQ_CLS, r=16,
+                                 lora_alpha=16, lora_dropout=0.05, bias="none",
+                                 target_modules=lora_target_modules)
+             self.lm = get_peft_model(self.lm, config)
+             self.lm.print_trainable_parameters()
+         self.lm.config.pad_token_id = self.lm.config.eos_token_id
+         self.lm_device = self.lm.device
+
+         if self.lm.num_labels != self.gnn.out_channels:
+             raise ValueError('The output channels of the language model '
+                              'and the GNN must be the same.')
+
+     def pre_train_gnn(self, train_loader: NeighborLoader,
+                       optimizer: torch.optim.Optimizer, num_epochs: int,
+                       patience: int, ext_pseudo_labels: torch.Tensor = None,
+                       is_augmented: bool = False, verbose: bool = True):
+         # Pretrain the GNN; an optional step if you do not have pseudo labels.
+         best_acc = 0
+         early_stopping = 0
+         # Training is based on gold data only:
+         for epoch in range(0, num_epochs):
+             acc, loss = self.train_gnn(train_loader, optimizer, epoch,
+                                        ext_pseudo_labels, is_augmented,
+                                        verbose)
+             if acc < best_acc:
+                 early_stopping += 1
+                 if early_stopping > patience:
+                     print(f'Early stopped by Epoch: {epoch}, '
+                           f'Best acc: {best_acc}')
+                     break
+             best_acc = max(best_acc, acc)
+
+     def pre_train_lm(self, train_loader: DataLoader,
+                      optimizer: torch.optim.Optimizer, num_epochs: int,
+                      patience: int, ext_pseudo_labels: torch.Tensor = None,
+                      is_augmented: bool = False, verbose: bool = True):
+         # Pretrain the language model:
+         best_acc = 0
+         early_stopping = 0
+         for epoch in range(1, num_epochs + 1):
+             acc, loss = self.train_lm(train_loader, optimizer, epoch,
+                                       ext_pseudo_labels, is_augmented,
+                                       verbose)
+             if acc < best_acc:
+                 early_stopping += 1
+                 if early_stopping > patience:
+                     print(f'Early stopped by Epoch: {epoch}, '
+                           f'Best acc: {best_acc}')
+                     break
+             best_acc = max(best_acc, acc)
+
+     def train(self, em_phase: str,
+               train_loader: Union[DataLoader, NeighborLoader],
+               optimizer: torch.optim.Optimizer, pseudo_labels: torch.Tensor,
+               epoch: int, is_augmented: bool = False, verbose: bool = False):
+         r"""GLEM training step (one EM step).
+
+         Args:
+             em_phase (str): 'gnn' or 'lm', the phase to train in this step.
+             train_loader (Union[DataLoader, NeighborLoader]): Use a
+                 DataLoader (tokenized data, labels, and is_gold mask) for LM
+                 training, and a NeighborLoader (x, edge_index) for GNN
+                 training.
+             optimizer (torch.optim.Optimizer): Optimizer for training.
+             pseudo_labels (torch.Tensor): The predicted labels used as
+                 pseudo labels.
+             epoch (int): Current epoch.
+             is_augmented (bool): Whether to use pseudo labels.
+             verbose (bool): Whether to print a training progress bar.
+
+         Returns:
+             acc (float): Training accuracy.
+             loss (float): Loss value.
+         """
+         pseudo_labels = pseudo_labels.to(self.device)
+         if em_phase == 'gnn':
+             acc, loss = self.train_gnn(train_loader, optimizer, epoch,
+                                        pseudo_labels, is_augmented, verbose)
+         elif em_phase == 'lm':
+             acc, loss = self.train_lm(train_loader, optimizer, epoch,
+                                       pseudo_labels, is_augmented, verbose)
+         return acc, loss
+
+     def train_lm(self, train_loader: DataLoader,
+                  optimizer: torch.optim.Optimizer, epoch: int,
+                  pseudo_labels: torch.Tensor = None,
+                  is_augmented: bool = False, verbose: bool = True):
+         r"""Trains the language model for one epoch.
+
+         Args:
+             train_loader (loader.dataloader.DataLoader): Text token
+                 dataloader.
+             optimizer (torch.optim.Optimizer): Model optimizer.
+             epoch (int): Current training epoch.
+             pseudo_labels (torch.Tensor): 1-D tensor of predictions from the
+                 GNN.
+             is_augmented (bool): Whether to train with pseudo labels.
+             verbose (bool): Whether to print a training progress bar.
+
+         Returns:
+             approx_acc (torch.tensor): Training accuracy.
+             loss (torch.float): Loss value.
+         """
+         all_out = []
+         total_loss = total_correct = 0
+         num_nodes = train_loader.dataset.indices.size(0)
+         self.lm.train()
+         if verbose:
+             pbar = tqdm(total=num_nodes)
+             pbar.set_description(f'Epoch {epoch:02d}')
+         for batch in train_loader:
+             inputs = {k: v.to(self.device) for k, v in batch['input'].items()}
+             out = self.lm(**inputs).logits
+             labels = batch['labels'].to(self.device).squeeze()
+             # Train with pseudo labels or not:
+             if is_augmented:
+                 pl_batch = pseudo_labels[batch['n_id']].to(self.device)
+             else:
+                 pl_batch = None
+             loss = self.loss(out, labels, self.lm_loss,
+                              batch['is_gold'].to(self.device), pl_batch,
+                              self.alpha, is_augmented)
+             loss.backward()
+             optimizer.step()
+             optimizer.zero_grad()
+             all_out.append(out)
+             total_correct += int(out.argmax(dim=-1).eq(labels).sum())
+             total_loss += float(loss)
+             if verbose:
+                 pbar.update(batch['n_id'].size(0))
+
+         all_out = torch.cat(all_out, dim=0)
+         approx_acc = total_correct / num_nodes
+         loss = total_loss / len(train_loader)
+         if verbose:
+             pbar.close()
+             print(f'Epoch {epoch:02d} Loss: {loss:.4f} '
+                   f'Approx. Train: {approx_acc:.4f}')
+         return approx_acc, loss
+
+     def train_gnn(self, train_loader: NeighborLoader,
+                   optimizer: torch.optim.Optimizer, epoch: int,
+                   pseudo_labels: torch.Tensor = None,
+                   is_augmented: bool = False, verbose: bool = True):
+         r"""Trains the GNN for one epoch.
+
+         Args:
+             train_loader (loader.NeighborLoader): GNN neighbor node loader.
+             optimizer (torch.optim.Optimizer): Model optimizer.
+             epoch (int): Current training epoch.
+             pseudo_labels (torch.Tensor): 1-D tensor of predictions from the
+                 LM.
+             is_augmented (bool): Whether to use pseudo-labeled nodes.
+             verbose (bool): Whether to print training progress.
+
+         Returns:
+             approx_acc (torch.tensor): Training accuracy.
+             loss (torch.float): Loss value.
+         """
+         self.gnn.train()
+         num_nodes = train_loader.input_nodes.size(0)
+         if verbose:
+             pbar = tqdm(total=num_nodes)
+             pbar.set_description(f'Epoch {epoch:02d}')
+         total_loss = total_correct = 0
+         all_out = []
+         for batch in train_loader:
+             batch = batch.to(self.device)
+             out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size]
+             all_out.append(out)
+             labels = batch.y[:batch.batch_size].squeeze()
+             is_gold_batch = batch.is_gold[:batch.batch_size].squeeze()
+             # Train with pseudo labels or not:
+             if is_augmented and pseudo_labels is not None:
+                 pl_batch = pseudo_labels[batch.n_id[:batch.batch_size]]
+             else:
+                 pl_batch = None
+             loss = self.loss(out, labels, self.gnn_loss, is_gold_batch,
+                              pl_batch, self.beta, is_augmented)
+             loss.backward()
+             optimizer.step()
+             optimizer.zero_grad()
+             total_loss += float(loss)
+             total_correct += int(out.argmax(dim=-1).eq(labels).sum())
+             if verbose:
+                 pbar.update(batch.batch_size)
+
+         all_out = torch.cat(all_out, dim=0)
+         loss = total_loss / len(train_loader)
+         approx_acc = total_correct / num_nodes
+         if verbose:
+             pbar.close()
+             print(f'Epoch: {epoch:02d} Loss: {loss:.4f} '
+                   f'Approx. Train: {approx_acc:.4f}')
+         return approx_acc, loss
+
+     @torch.no_grad()
+     def inference(self, em_phase: str,
+                   data_loader: Union[NeighborLoader, DataLoader],
+                   verbose: bool = False):
+         r"""GLEM inference step.
+
+         Args:
+             em_phase (str): 'gnn' or 'lm'.
+             data_loader (DataLoader or NeighborLoader): A DataLoader with
+                 tokenized data for LM inference, or a NeighborLoader with x
+                 and edge_index for GNN inference.
+             verbose (bool): Whether to print inference progress.
+
+         Returns:
+             out (torch.Tensor): An n * m tensor, where m is the number of
+                 classes and n is the number of nodes.
+         """
+         out = None
+         if em_phase == 'gnn':
+             self.gnn.eval()
+             out = self.inference_gnn(data_loader, verbose)
+         elif em_phase == 'lm':
+             self.lm.eval()
+             out = self.inference_lm(data_loader, verbose)
+         return out
+
+     @torch.no_grad()
+     def inference_lm(self, data_loader: DataLoader, verbose: bool = True):
+         r"""LM inference step.
+
+         Args:
+             data_loader (DataLoader): Includes tokens, labels, and the gold
+                 mask.
+             verbose (bool): Whether to print a progress bar.
+
+         Returns:
+             preds (tensor): Predictions from the LM; convert to pseudo
+                 labels via preds.argmax(dim=-1).unsqueeze(1).
+         """
+         if verbose:
+             pbar = tqdm(total=data_loader.dataset._data.num_nodes)
+             pbar.set_description('LM inference stage')
+         self.lm.eval()
+         preds = []
+         for batch in data_loader:
+             inputs = {k: v.to(self.device) for k, v in batch['input'].items()}
+             logits = self.lm(**inputs).logits
+             preds.append(logits)
+             if verbose:
+                 pbar.update(batch['n_id'].size(0))
+         if verbose:
+             pbar.close()
+         preds = torch.cat(preds)
+         return preds
+
+     @torch.no_grad()
+     def inference_gnn(self, data_loader: NeighborLoader, verbose: bool = True):
+         r"""GNN inference step.
+
+         Args:
+             data_loader (NeighborLoader): Includes x and edge_index.
+             verbose (bool): Whether to print a progress bar.
+
+         Returns:
+             preds (tensor): Predictions from the GNN; convert to pseudo
+                 labels via preds.argmax(dim=-1).unsqueeze(1).
+         """
+         if verbose:
+             pbar = tqdm(total=data_loader.data.num_nodes)
+             pbar.set_description('GNN inference stage')
+         preds = []
+         self.gnn.eval()
+         for batch in data_loader:
+             batch = batch.to(self.device)
+             out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size]
+             preds.append(out)
+             if verbose:
+                 pbar.update(batch.batch_size)
+         if verbose:
+             pbar.close()
+         preds = torch.cat(preds, dim=0)
+         return preds
+
+     def loss(self, logits: torch.Tensor, labels: torch.Tensor,
+              loss_func: torch.nn.functional, is_gold: torch.Tensor,
+              pseudo_labels: torch.Tensor = None, pl_weight: float = 0.5,
+              is_augmented: bool = True):
+         r"""Core function of variational EM inference. It combines the loss
+         on gold (original training) labels with the loss on pseudo labels.
+
+         Reference:
+         <https://github.com/AndyJZhao/GLEM/blob/main/src/models/GLEM/GLEM_utils.py> # noqa
+
+         Args:
+             logits (torch.Tensor): Predictions from the LM or GNN.
+             labels (torch.Tensor): Combined node labels from ground truth
+                 and pseudo labels (if provided).
+             loss_func (torch.nn.modules.loss): Loss function for
+                 classification.
+             is_gold (Tensor): Boolean mask marking ground-truth labels used
+                 during training; ~is_gold therefore masks pseudo-labeled
+                 nodes.
+             pseudo_labels (torch.Tensor): Predictions from the other model.
+             pl_weight: Weight of the pseudo-label loss, alpha in the E-step
+                 and beta in the M-step, respectively.
+             is_augmented: Whether to use the EM objective or to train the
+                 GNN and LM on gold data only.
+         """
+         def deal_nan(x):
+             return 0 if torch.isnan(x) else x
+
+         if is_augmented and (sum(~is_gold) > 0):
+             mle_loss = deal_nan(loss_func(logits[is_gold], labels[is_gold]))
+             # All labels other than the ground-truth (gold) labels:
+             pseudo_label_loss = deal_nan(
+                 loss_func(logits[~is_gold], pseudo_labels[~is_gold]))
+             loss = pl_weight * pseudo_label_loss + (1 - pl_weight) * mle_loss
+         else:
+             loss = loss_func(logits, labels)
+         return loss
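
The GLEM class above exposes pre_train_gnn/pre_train_lm for warm-up and train/inference for the EM rounds. The snippet below is a minimal sketch of the GNN side only; the `data` object, channel sizes, and optimizer settings are illustrative assumptions, and the full co-training loop lives in examples/llm_plus_gnn/glem.py.

    # Minimal sketch (illustrative only; not the shipped example).
    # Assumptions: `data` is a torch_geometric.data.Data object carrying
    # x, edge_index, y, train_mask and a boolean `is_gold` node mask, and
    # the task has 47 classes (the constructor default above).
    import torch

    from torch_geometric.loader import NeighborLoader
    from torch_geometric.nn.models import GLEM, GraphSAGE

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    gnn = GraphSAGE(in_channels=data.num_features, hidden_channels=256,
                    num_layers=2, out_channels=47).to(device)
    model = GLEM(lm_to_use='prajjwal1/bert-tiny', gnn_to_use=gnn,
                 out_channels=47, device=device)

    train_loader = NeighborLoader(data, num_neighbors=[10, 10],
                                  batch_size=1024,
                                  input_nodes=data.train_mask)
    gnn_opt = torch.optim.Adam(model.gnn.parameters(), lr=0.002)

    # M-step warm-up on gold labels only:
    model.pre_train_gnn(train_loader, gnn_opt, num_epochs=5, patience=3)

    # GNN predictions become pseudo labels for the LM E-step; the shipped
    # example gathers them for all nodes (in node order) before alternating
    # model.train('lm', ...) and model.train('gnn', ...).
    pseudo_labels = model.inference('gnn', train_loader).argmax(dim=-1)

The LM side mirrors this with a text DataLoader whose batches expose `input`, `labels`, `n_id`, and `is_gold`, as train_lm() above expects.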
@@ -0,0 +1,222 @@
+ from typing import List, Optional
+
+ import torch
+ from torch import Tensor
+
+ from torch_geometric.nn.attention import QFormer
+ from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS
+ from torch_geometric.utils import to_dense_batch
+
+
+ def pad_or_truncate(embeddings: Tensor, max_seq_len: int,
+                     padding_value: int = 0) -> Tensor:
+     batch_size, current_seq_len, d = embeddings.size()
+
+     if current_seq_len > max_seq_len:
+         return embeddings[:, :max_seq_len, :]
+     elif current_seq_len < max_seq_len:
+         pad_tensor = torch.full((batch_size, max_seq_len - current_seq_len, d),
+                                 padding_value, dtype=embeddings.dtype,
+                                 device=embeddings.device)
+         return torch.cat([embeddings, pad_tensor], dim=1)
+     else:
+         return embeddings
+
+
+ class MoleculeGPT(torch.nn.Module):
+     r"""The MoleculeGPT model from the `"MoleculeGPT: Instruction
+     Following Large Language Models for Molecular Property Prediction"
+     <https://ai4d3.github.io/papers/34.pdf>`_ paper.
+
+     Args:
+         llm (LLM): The LLM to use.
+         graph_encoder (torch.nn.Module): Encode 2D molecule graph.
+         smiles_encoder (torch.nn.Module): Encode 1D SMILES.
+         mlp_out_channels (int, optional): The size of each embedding
+             after qformer encoding. (default: :obj:`32`)
+         max_tokens (int, optional): Max output tokens of 1D/2D encoder.
+             (default: :obj:`20`)
+
+     .. warning::
+         This module has been tested with the following HuggingFace models
+
+         * :obj:`llm_to_use="lmsys/vicuna-7b-v1.5"`
+
+         and may not work with other models. See other models at `HuggingFace
+         Models <https://huggingface.co/models>`_ and let us know if you
+         encounter any issues.
+
+     .. note::
+         For an example of using :class:`MoleculeGPT`, see
+         `examples/llm/molecule_gpt.py <https://github.com/pyg-team/
+         pytorch_geometric/blob/master/examples/llm/molecule_gpt.py>`_.
+     """
+     def __init__(
+         self,
+         llm: LLM,
+         graph_encoder: torch.nn.Module,
+         smiles_encoder: torch.nn.Module,
+         mlp_out_channels: int = 32,
+         max_tokens: Optional[int] = 20,
+     ) -> None:
+         super().__init__()
+         self.llm = llm
+         self.graph_encoder = graph_encoder.to(self.llm.device)
+         self.smiles_encoder = smiles_encoder.to(self.llm.device)
+
+         self.graph_qformer = QFormer(
+             input_dim=self.graph_encoder.nn[-1].out_features,
+             hidden_dim=mlp_out_channels,
+             output_dim=mlp_out_channels,
+             num_heads=4,
+             num_layers=2,
+         ).to(self.llm.device)
+
+         self.smiles_qformer = QFormer(
+             input_dim=self.smiles_encoder.model.pooler.dense.out_features,
+             hidden_dim=mlp_out_channels,
+             output_dim=mlp_out_channels,
+             num_heads=4,
+             num_layers=2,
+         ).to(self.llm.device)
+
+         self.max_tokens = max_tokens
+
+         self.word_embedding = self.llm.word_embedding
+         self.llm_generator = self.llm.llm
+
+         # LLMs
+         in_dim = 2 * mlp_out_channels * max_tokens
+         out_dim = self.llm.llm.model.embed_tokens.embedding_dim
+         self.projector = torch.nn.Sequential(
+             torch.nn.Linear(in_dim, in_dim),
+             torch.nn.Sigmoid(),
+             torch.nn.Linear(in_dim, out_dim),
+         ).to(self.llm.device)
+
+     def encode(
+         self,
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+         edge_attr: Optional[Tensor],
+         smiles: List[str],
+     ) -> Tensor:
+         batch_size = len(smiles)
+         # 2D Graph Branch: [bs, node_len, d]
+         x = x.to(self.llm.device)
+         edge_index = edge_index.to(self.llm.device)
+         if edge_attr is not None:
+             edge_attr = edge_attr.to(self.llm.device)
+         batch = batch.to(self.llm.device)
+
+         x_graph = self.graph_encoder(x, edge_index, edge_attr=edge_attr)
+         x_graph = to_dense_batch(x_graph, batch)[0]
+         out_graph = self.graph_qformer(x_graph)
+         out_graph = pad_or_truncate(out_graph, max_seq_len=self.max_tokens,
+                                     padding_value=0)
+         out_graph = out_graph.view(batch_size, -1)
+
+         # 1D SMILES Branch: [bs, seq_len, d]
+         x_smiles = self.smiles_encoder.encode(smiles,
+                                               output_device=self.llm.device)
+         out_smiles = self.smiles_qformer(x_smiles)
+         out_smiles = pad_or_truncate(out_smiles, max_seq_len=self.max_tokens,
+                                      padding_value=0)
+         out_smiles = out_smiles.view(batch_size, -1)
+
+         # Merge into LLMs
+         x_cat = torch.cat([out_graph, out_smiles], dim=1)
+         return x_cat
+
+     def forward(
+         self,
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+         edge_attr: Optional[Tensor],
+         smiles: List[str],
+         instructions: List[str],
+         label: List[str],
+         additional_text_context: Optional[List[str]] = None,
+     ):
+         x = self.encode(x, edge_index, batch, edge_attr, smiles)
+         x = self.projector(x)
+         xs = x.split(1, dim=0)
+
+         batch_unique = batch.unique()
+         batch_size = len(instructions)
+         if len(batch_unique) < batch_size:
+             xs = [
+                 xs[i] if i in batch_unique else None for i in range(batch_size)
+             ]
+
+         (
+             inputs_embeds,
+             attention_mask,
+             label_input_ids,
+         ) = self.llm._get_embeds(instructions, additional_text_context, xs,
+                                  label)
+
+         with self.llm.autocast_context:
+             outputs = self.llm_generator(
+                 inputs_embeds=inputs_embeds,
+                 attention_mask=attention_mask,
+                 return_dict=True,
+                 labels=label_input_ids,
+             )
+
+         return outputs.loss
+
+     @torch.no_grad()
+     def inference(
+         self,
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+         edge_attr: Optional[Tensor],
+         smiles: List[str],
+         instructions: List[str],
+         additional_text_context: Optional[List[str]] = None,
+         max_out_tokens: Optional[int] = MAX_NEW_TOKENS,
+     ):
+         x = self.encode(x, edge_index, batch, edge_attr, smiles)
+         x = self.projector(x)
+         xs = x.split(1, dim=0)
+
+         # Handle questions without node features:
+         batch_unique = batch.unique()
+         batch_size = len(instructions)
+         if len(batch_unique) < batch_size:
+             xs = [
+                 xs[i] if i in batch_unique else None for i in range(batch_size)
+             ]
+
+         inputs_embeds, attention_mask, _ = self.llm._get_embeds(
+             instructions, additional_text_context, xs)
+
+         bos_token = self.llm.tokenizer(
+             BOS,
+             add_special_tokens=False,
+         ).input_ids[0]
+
+         with self.llm.autocast_context:
+             outputs = self.llm_generator.generate(
+                 inputs_embeds=inputs_embeds,
+                 max_new_tokens=max_out_tokens,
+                 attention_mask=attention_mask,
+                 bos_token_id=bos_token,
+                 use_cache=True,  # Important to set!
+             )
+
+         return self.llm.tokenizer.batch_decode(
+             outputs,
+             skip_special_tokens=True,
+         )
+
+     def __repr__(self) -> str:
+         return (f'{self.__class__.__name__}(\n'
+                 f'  llm={self.llm},\n'
+                 f'  graph={self.graph_encoder.__class__.__name__},\n'
+                 f'  smiles={self.smiles_encoder},\n'
+                 f')')
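
As read in `__init__` and `encode()` above, MoleculeGPT expects a graph encoder whose final layer is reachable as `graph_encoder.nn[-1]` and a SMILES encoder exposing `model.pooler.dense` plus an `encode()` method. The sketch below is illustrative only: the tiny graph encoder, the ChemBERTa checkpoint, the feature sizes, and the `batch`/`instructions`/`labels` placeholders are assumptions, not the configuration of examples/llm/molecule_gpt.py.

    # Illustrative wiring sketch for MoleculeGPT (not the shipped example).
    import torch

    from torch_geometric.nn import GCNConv
    from torch_geometric.nn.models import MoleculeGPT
    from torch_geometric.nn.nlp import LLM, SentenceTransformer


    class TinyGraphEncoder(torch.nn.Module):
        # Hypothetical stand-in: MoleculeGPT only needs per-node features
        # and a final layer reachable as `encoder.nn[-1]` (a Linear) to
        # size its QFormer branch.
        def __init__(self, in_channels: int, hidden_channels: int) -> None:
            super().__init__()
            self.conv = GCNConv(in_channels, hidden_channels)
            self.nn = torch.nn.Sequential(
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_channels, hidden_channels),
            )

        def forward(self, x, edge_index, edge_attr=None):
            return self.nn(self.conv(x, edge_index))


    llm = LLM(model_name='lmsys/vicuna-7b-v1.5', num_params=7)
    graph_encoder = TinyGraphEncoder(in_channels=9, hidden_channels=64)
    smiles_encoder = SentenceTransformer(
        model_name='DeepChem/ChemBERTa-77M-MTR',  # assumed checkpoint
        pooling_strategy='last_hidden_state',
    )
    model = MoleculeGPT(llm, graph_encoder, smiles_encoder)

    # `batch` (with x, edge_index, batch, edge_attr, smiles),
    # `instructions`, and `labels` are placeholders for a molecular
    # instruction-tuning dataset:
    loss = model(batch.x, batch.edge_index, batch.batch, batch.edge_attr,
                 batch.smiles, instructions, labels)
    answers = model.inference(batch.x, batch.edge_index, batch.batch,
                              batch.edge_attr, batch.smiles, instructions)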
@@ -56,7 +56,7 @@ class LLM(torch.nn.Module):
              allocate the correct number of GPUs needed, given the available GPU
              memory of your GPUs.
          dtype (torch.dtype, optional): The data type to use for the LLM.
-             (default :obj: `torch.bloat16`)
+             (default :obj: `torch.bfloat16`)
      """
      def __init__(
          self,
@@ -10,6 +10,7 @@ class PoolingStrategy(Enum):
      MEAN = 'mean'
      LAST = 'last'
      CLS = 'cls'
+     LAST_HIDDEN_STATE = 'last_hidden_state'


  class SentenceTransformer(torch.nn.Module):
@@ -38,6 +39,8 @@ class SentenceTransformer(torch.nn.Module):
              emb = mean_pooling(emb, attention_mask)
          elif self.pooling_strategy == PoolingStrategy.LAST:
              emb = last_pooling(emb, attention_mask)
+         elif self.pooling_strategy == PoolingStrategy.LAST_HIDDEN_STATE:
+             emb = out.last_hidden_state
          else:
              assert self.pooling_strategy == PoolingStrategy.CLS
              emb = emb[:, 0, :]
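
The new LAST_HIDDEN_STATE strategy skips pooling entirely and returns the token-level hidden states, which is what token-sequence consumers such as the QFormer branches in MoleculeGPT above need. A small hedged sketch follows; the model name and the printed shapes are illustrative assumptions.

    from torch_geometric.nn.nlp import SentenceTransformer

    texts = ['CCO', 'c1ccccc1']

    pooled = SentenceTransformer('prajjwal1/bert-tiny',
                                 pooling_strategy='mean')
    tokens = SentenceTransformer('prajjwal1/bert-tiny',
                                 pooling_strategy='last_hidden_state')

    print(pooled.encode(texts).shape)  # one vector per text, e.g. [2, 128]
    print(tokens.encode(texts).shape)  # per-token states, e.g. [2, seq_len, 128]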