pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +14 -2
- torch_geometric/_compile.py +9 -3
- torch_geometric/_onnx.py +214 -0
- torch_geometric/config_mixin.py +5 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +109 -5
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +14 -11
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +18 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +15 -17
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +310 -209
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/partition.py +2 -2
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +18 -14
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +87 -3
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +14 -5
- torch_geometric/inspector.py +4 -0
- torch_geometric/io/fs.py +5 -4
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +4 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +14 -0
- torch_geometric/metrics/link_pred.py +745 -92
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +5 -4
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +0 -20
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +14 -12
- torch_geometric/nn/model_hub.py +2 -15
- torch_geometric/nn/models/__init__.py +11 -2
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/cluster_pool.py +3 -4
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +336 -7
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +296 -23
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -8
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/typing.py +70 -17
- torch_geometric/utils/__init__.py +4 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +27 -12
- torch_geometric/utils/_scatter.py +132 -195
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_to_dense_batch.py +2 -2
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +13 -9
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +7 -14
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/sentence_transformer.py +0 -101
--- /dev/null
+++ torch_geometric/llm/models/glem.py
@@ -0,0 +1,397 @@
+from typing import List, Optional, Union
+
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+
+from torch_geometric.loader import DataLoader, NeighborLoader
+from torch_geometric.nn.models import GraphSAGE, basic_gnn
+
+
+def deal_nan(x):
+    if isinstance(x, torch.Tensor):
+        x = x.clone()
+        x[torch.isnan(x)] = 0.0
+    return x
+
+
+class GLEM(torch.nn.Module):
+    r"""This GNN+LM co-training model is based on GLEM from the `"Learning on
+    Large-scale Text-attributed Graphs via Variational Inference"
+    <https://arxiv.org/abs/2210.14709>`_ paper.
+
+    Args:
+        lm_to_use (str): A TextEncoder from huggingface model repo
+            with a classifier(default: TinyBERT)
+        gnn_to_use (torch_geometric.nn.models): (default: GraphSAGE)
+        out_channels (int): output channels for LM and GNN, should be same
+        num_gnn_heads Optional[int]: Number of heads for attention, if needed
+        num_gnn_layers (int): number of gnn layers
+        gnn_loss: loss function for gnn, (default: CrossEntropyLoss)
+        lm_loss: loss function for Language Model, (default: CrossEntropyLoss)
+        alpha (float): pseudo label weight of E-step, LM optimization,
+            (default: 0.5)
+        beta (float): pseudo label weight of M-step, GNN optimization,
+            (default: 0.5)
+        lm_dtype (torch.dtype): the data type once you load LM into memory,
+            (default: torch.bfloat16)
+        lm_use_lora (bool): choose if LM use Lora peft for fine tune,
+            (default: True)
+        lora_target_modules: The names of the target modules to apply the lora
+            adapter to, e.g. ['q_proj', 'v_proj'] for LLM , (default: None)
+
+    .. note::
+        See `examples/llm_plus_gnn/glem.py` for example usage.
+    """
+    def __init__(
+        self,
+        lm_to_use: str = 'prajjwal1/bert-tiny',
+        gnn_to_use: basic_gnn = GraphSAGE,
+        out_channels: int = 47,
+        gnn_loss: Optional[nn.Module] = None,
+        lm_loss: Optional[nn.Module] = None,
+        alpha: float = 0.5,
+        beta: float = 0.5,
+        lm_dtype: torch.dtype = torch.bfloat16,
+        lm_use_lora: bool = True,
+        lora_target_modules: Optional[Union[List[str], str]] = None,
+        device: Optional[Union[str, torch.device]] = None,
+    ):
+        super().__init__()
+
+        if gnn_loss is None:
+            gnn_loss = nn.CrossEntropyLoss(reduction='mean')
+        if lm_loss is None:
+            lm_loss = nn.CrossEntropyLoss(reduction='mean')
+        if device is None:
+            device = torch.device('cpu')
+
+        self.device = device
+        self.lm_loss = lm_loss
+        self.gnn = gnn_to_use
+        self.gnn_loss = gnn_loss
+        self.alpha = alpha
+        self.beta = beta
+        self.gnn_loss = gnn_loss
+        self.lm = lm_to_use
+        from transformers import AutoModelForSequenceClassification
+        self.lm = AutoModelForSequenceClassification.from_pretrained(
+            lm_to_use, num_labels=out_channels, torch_dtype=lm_dtype,
+            offload_folder="offload", trust_remote_code=True)
+        if lm_use_lora:
+            from peft import (
+                LoraConfig,
+                TaskType,
+                get_peft_model,
+                prepare_model_for_kbit_training,
+            )
+            print("Training LM with LORA!")
+            self.lm = prepare_model_for_kbit_training(self.lm)
+            config = LoraConfig(task_type=TaskType.SEQ_CLS, r=16,
+                                lora_alpha=16, lora_dropout=0.05, bias="none",
+                                target_modules=lora_target_modules)
+            self.lm = get_peft_model(self.lm, config)
+            self.lm.print_trainable_parameters()
+        self.lm.config.pad_token_id = self.lm.config.eos_token_id
+        self.lm_device = self.lm.device
+
+        if self.lm.num_labels != self.gnn.out_channels:
+            raise ValueError('''The output channel of language model \
+                             and gnn should be the same''')
+
+    def pre_train_gnn(self, train_loader: NeighborLoader,
+                      optimizer: torch.optim.Optimizer, num_epochs: int,
+                      patience: int, ext_pseudo_labels: torch.Tensor = None,
+                      is_augmented: bool = False, verbose: bool = True):
+        # Pretrain GNN, optional steps if you do not have pseudo labels.
+        best_acc = 0
+        early_stopping = 0
+        # training only based on gold data
+        for epoch in range(0, num_epochs):
+            acc, loss = self.train_gnn(train_loader, optimizer, epoch,
+                                       ext_pseudo_labels, is_augmented,
+                                       verbose)
+            if acc < best_acc:
+                early_stopping += 1
+                if early_stopping > patience:
+                    print(f'Early stopped by Epoch: {epoch}, '
+                          f'Best acc: {best_acc}')
+                    break
+            best_acc = max(best_acc, acc)
+
+    def pre_train_lm(self, train_loader: DataLoader,
+                     optimizer: torch.optim.Optimizer, num_epochs: int,
+                     patience: int, ext_pseudo_labels: torch.Tensor = None,
+                     is_augmented: bool = False, verbose: bool = True):
+        # Pretrain language model
+        best_acc = 0
+        early_stopping = 0
+        for epoch in range(1, num_epochs + 1):
+            acc, loss = self.train_lm(train_loader, optimizer, epoch,
+                                      ext_pseudo_labels, is_augmented, verbose)
+            if acc < best_acc:
+                early_stopping += 1
+                if early_stopping > patience:
+                    print(f'Early stopped by Epoch: {epoch}, '
+                          f'Best acc: {best_acc}')
+                    break
+            best_acc = max(best_acc, acc)
+
+    def train(self, em_phase: str, train_loader: Union[DataLoader,
+                                                       NeighborLoader],
+              optimizer: torch.optim.Optimizer, pseudo_labels: torch.Tensor,
+              epoch: int, is_augmented: bool = False, verbose: bool = False):
+        r"""GLEM training step, EM steps.
+
+        Args:
+            em_phase(str): 'gnn' or 'lm' choose which phase you are training on
+            train_loader(Union[DataLoader, NeighborLoader]): use DataLoader for
+                lm training, include tokenized data, labels is_gold mask.
+                use NeighborLoader for gnn training, include x, edge_index.
+            optimizer (torch.optim.Optimizer): optimizer for training
+            pseudo_labels(torch.Tensor): the predicted labels used as pseudo
+                labels
+            epoch (int): current epoch
+            is_augmented (bool): will use pseudo_labels or not
+            verbose (bool): print training progress bar or not
+
+        Returns:
+            acc (float): training accuracy
+            loss (float): loss value
+        """
+        if pseudo_labels is not None:
+            pseudo_labels = pseudo_labels.to(self.device)
+        if em_phase == 'gnn':
+            acc, loss = self.train_gnn(train_loader, optimizer, epoch,
+                                       pseudo_labels, is_augmented, verbose)
+        if em_phase == 'lm':
+            acc, loss = self.train_lm(train_loader, optimizer, epoch,
+                                      pseudo_labels, is_augmented, verbose)
+        return acc, loss
+
+    def train_lm(self, train_loader: DataLoader,
+                 optimizer: torch.optim.Optimizer, epoch: int,
+                 pseudo_labels: torch.Tensor = None,
+                 is_augmented: bool = False, verbose: bool = True):
+        r"""Language model Training in every epoch.
+
+        Args:
+            train_loader (loader.dataloader.DataLoader): text token dataloader
+            optimizer (torch.optim.Optimizer): model optimizer
+            epoch (int): current train epoch
+            pseudo_labels (torch.Tensor): 1-D tensor, predictions from gnn
+            is_augmented (bool): train with pseudo labels or not
+            verbose (bool): print training progress bar or not
+
+        Returns:
+            approx_acc (torch.tensor): training accuracy
+            loss (torch.float): loss value
+
+        """
+        all_out = []
+        total_loss = total_correct = 0
+        num_nodes = train_loader.dataset.indices.size(0)
+        self.lm.train()
+        if verbose:
+            pbar = tqdm(total=num_nodes)
+            pbar.set_description(f'Epoch {epoch:02d}')
+        for batch in train_loader:
+            inputs = {k: v.to(self.device) for k, v in batch['input'].items()}
+            out = self.lm(**inputs).logits
+            labels = batch['labels'].to(self.device).squeeze()
+            # training with pseudo labels or not
+            if is_augmented:
+                pl_batch = pseudo_labels[batch['n_id']].to(self.device)
+            else:
+                pl_batch = None
+            loss = self.loss(out, labels, self.lm_loss,
+                             batch['is_gold'].to(self.device), pl_batch,
+                             self.alpha, is_augmented)
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+            all_out.append(out)
+            total_correct += int(out.argmax(dim=-1).eq(labels).sum())
+            total_loss += float(loss.detach())
+            if verbose:
+                pbar.update(batch['n_id'].size(0))
+
+        all_out = torch.cat(all_out, dim=0)
+        approx_acc = total_correct / num_nodes
+        loss = total_loss / len(train_loader)
+        if verbose:
+            pbar.close()
+            print(f'Epoch {epoch:02d} Loss: {loss:.4f} '
+                  f'Approx. Train: {approx_acc:.4f}')
+        return approx_acc, loss
+
+    def train_gnn(self, train_loader: NeighborLoader,
+                  optimizer: torch.optim.Optimizer, epoch: int,
+                  pseudo_labels: torch.Tensor = None,
+                  is_augmented: bool = False, verbose: bool = True):
+        r"""GNN training step in every epoch.
+
+        Args:
+            train_loader (loader.NeighborLoader): gnn Neighbor node loader
+            optimizer (torch.optim.Optimizer): model optimizer
+            epoch (int): current train epoch
+            pseudo_labels(torch.tensor): 1-D tensor, predictions from lm
+            is_augmented(bool): use pseudo labeled node or not
+            verbose (bool): print training progress or not
+
+        Returns:
+            approx_acc (torch.tensor): training accuracy
+            loss (torch.float): loss value
+        """
+        self.gnn.train()
+        num_nodes = train_loader.input_nodes.size(0)
+        if verbose:
+            pbar = tqdm(total=num_nodes)
+            pbar.set_description(f'Epoch {epoch:02d}')
+        total_loss = total_correct = 0
+        all_out = []
+        for batch in train_loader:
+            batch = batch.to(self.device)
+            out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size]
+            all_out.append(out)
+            labels = batch.y[:batch.batch_size].squeeze()
+            is_gold_batch = batch.is_gold[:batch.batch_size].squeeze()
+            # training with pseudo labels or not
+            if is_augmented and pseudo_labels is not None:
+                pl_batch = pseudo_labels[batch.n_id[:batch.batch_size]]
+            else:
+                pl_batch = None
+            loss = self.loss(out, labels, self.gnn_loss, is_gold_batch,
+                             pl_batch, self.beta, is_augmented)
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+            total_loss += float(loss.detach())
+            total_correct += int(out.argmax(dim=-1).eq(labels).sum())
+            if verbose:
+                pbar.update(batch.batch_size)
+
+        all_out = torch.cat(all_out, dim=0)
+        loss = total_loss / len(train_loader)
+        approx_acc = total_correct / num_nodes
+        if verbose:
+            pbar.close()
+            print(f'Epoch: {epoch:02d} Loss: {loss:.4f} '
+                  f'Approx. Train: {approx_acc:.4f}')
+        return approx_acc, loss
+
+    @torch.no_grad()
+    def inference(self, em_phase: str, data_loader: Union[NeighborLoader,
+                                                          DataLoader],
+                  verbose: bool = False):
+        r"""GLEM inference step.
+
+        Args:
+            em_phase(str): 'gnn' or 'lm'
+            data_loader(dataloader or Neighborloader):
+                dataloader: for lm training, include tokenized data
+                nodeloader: for gnn training, include x, edge_index
+            verbose(bool): print inference progress or not
+
+        Returns:
+            out (torch.Tensor): n * m tensor, m is number of classes,
+                n is number of nodes
+        """
+        out = None
+        if em_phase == 'gnn':
+            self.gnn.eval()
+            out = self.inference_gnn(data_loader, verbose)
+        elif em_phase == 'lm':
+            self.lm.eval()
+            out = self.inference_lm(data_loader, verbose)
+        return out
+
+    @torch.no_grad()
+    def inference_lm(self, data_loader: DataLoader, verbose: bool = True):
+        r"""LM inference step.
+
+        Args:
+            data_loader (Dataloader): include token, labels, and gold mask
+            verbose (bool): print progress bar or not
+
+        Returns:
+            preds (tensor): prediction from GNN, convert to pseudo labels
+                by preds.argmax(dim=-1).unsqueeze(1)
+        """
+        if verbose:
+            pbar = tqdm(total=data_loader.dataset._data.num_nodes)
+            pbar.set_description('LM inference stage')
+        self.lm.eval()
+        preds = []
+        for batch in data_loader:
+            inputs = {k: v.to(self.device) for k, v in batch['input'].items()}
+            logits = self.lm(**inputs).logits
+            preds.append(logits)
+            if verbose:
+                pbar.update(batch['n_id'].size(0))
+        if verbose:
+            pbar.close()
+        preds = torch.cat(preds)
+        return preds
+
+    @torch.no_grad()
+    def inference_gnn(self, data_loader: NeighborLoader, verbose: bool = True):
+        r"""GNN inference step.
+
+        Args:
+            data_loader(NeighborLoader): include x, edge_index,
+            verbose (bool): print progress bar or not
+
+        Returns:
+            preds (tensor): prediction from GNN,
+                convert to pseudo labels by preds.argmax(dim=-1).unsqueeze(1)
+        """
+        if verbose:
+            pbar = tqdm(total=data_loader.data.num_nodes)
+            pbar.set_description('GNN inference stage')
+        preds = []
+        self.gnn.eval()
+        for batch in data_loader:
+            batch = batch.to(self.device)
+            out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size]
+            preds.append(out)
+            if verbose:
+                pbar.update(batch.batch_size)
+        if verbose:
+            pbar.close()
+        preds = torch.cat(preds, dim=0)
+        return preds
+
+    def loss(self, logits: torch.Tensor, labels: torch.Tensor,
+             loss_func: torch.nn.functional, is_gold: torch.Tensor,
+             pseudo_labels: torch.Tensor = None, pl_weight: float = 0.5,
+             is_augmented: bool = True):
+        r"""Core function of variational EM inference, this function is aming
+        on combining loss value on gold(original train) and loss value on
+        pseudo labels.
+
+        Reference:
+        <https://github.com/AndyJZhao/GLEM/blob/main/src/models/GLEM/GLEM_utils.py> # noqa
+
+        Args:
+            logits(torch.tensor): predict results from LM or GNN
+            labels(torch.tensor): combined node labels from ground truth and
+                pseudo labels(if provided)
+            loss_func(torch.nn.modules.loss): loss function for classification
+            is_gold(tensor): a tensor with bool value that mask ground truth
+                label and during training, thus ~is_gold mask pseudo labels
+            pseudo_labels(torch.tensor): predictions from other model
+            pl_weight: the pseudo labels used in E-step and M-step optimization
+                alpha in E-step, beta in M-step respectively
+            is_augmented: use EM or just train GNN and LM with gold data
+
+        """
+        if is_augmented and (sum(~is_gold) > 0):
+            mle_loss = deal_nan(loss_func(logits[is_gold], labels[is_gold]))
+            # all other labels beside from ground truth(gold labels)
+            pseudo_label_loss = deal_nan(
+                loss_func(logits[~is_gold], pseudo_labels[~is_gold]))
+            loss = pl_weight * pseudo_label_loss + (1 - pl_weight) * mle_loss
+        else:
+            loss = loss_func(logits, labels)
+        return loss
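The new `glem.py` module above only ships the model; wiring it into the E-step/M-step alternation is left to `examples/llm_plus_gnn/glem.py`. The sketch below illustrates how the `pre_train_*`, `train` and `inference` methods added in this diff are meant to alternate. It is a rough sketch, not code from the package: the import path is inferred from the new `torch_geometric/llm/models/` layout, and the loaders are placeholders for dataset-specific objects.

```python
# Hypothetical EM loop around the GLEM API added in this diff; device
# placement and loader construction are omitted.  `lm_loader`, `gnn_loader`,
# `lm_all_loader` and `gnn_all_loader` are placeholders (the LM side expects
# batches with 'input', 'labels', 'is_gold' and 'n_id' keys, the GNN side
# NeighborLoader batches carrying an `is_gold` node mask).
import torch
from torch_geometric.llm.models import GLEM  # path inferred from this diff
from torch_geometric.nn.models import GraphSAGE

gnn = GraphSAGE(in_channels=128, hidden_channels=256, num_layers=2,
                out_channels=47)
model = GLEM(lm_to_use='prajjwal1/bert-tiny', gnn_to_use=gnn, out_channels=47)
lm_opt = torch.optim.Adam(model.lm.parameters(), lr=1e-5)
gnn_opt = torch.optim.Adam(model.gnn.parameters(), lr=1e-3)

# Warm start on gold labels only, then alternate E- and M-steps.
model.pre_train_gnn(gnn_loader, gnn_opt, num_epochs=5, patience=2)
pseudo = model.inference('gnn', gnn_all_loader).argmax(dim=-1)

for em_iter in range(3):
    for epoch in range(2):  # E-step: LM on gold + GNN pseudo labels
        model.train('lm', lm_loader, lm_opt, pseudo, epoch, is_augmented=True)
    pseudo = model.inference('lm', lm_all_loader).argmax(dim=-1)
    for epoch in range(2):  # M-step: GNN on gold + LM pseudo labels
        model.train('gnn', gnn_loader, gnn_opt, pseudo, epoch,
                    is_augmented=True)
    pseudo = model.inference('gnn', gnn_all_loader).argmax(dim=-1)
```

The pseudo-label mixing itself happens in the `loss()` method at the end of the new file: `pl_weight * pseudo_label_loss + (1 - pl_weight) * mle_loss`, with `alpha` as the weight during the E-step (LM) and `beta` during the M-step (GNN).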
--- torch_geometric/nn/nlp/llm.py
+++ torch_geometric/llm/models/llm.py
@@ -10,15 +10,17 @@ try:
 except ImportError:
     BatchEncoding = Dict
 
-BOS = '<s>[INST]'
-EOS_USER = '[/INST]'
-EOS = '[/s]'
 IGNORE_INDEX = -100
 MAX_TXT_LEN = 512
-MAX_NEW_TOKENS =
+MAX_NEW_TOKENS = 128
 PAD_TOKEN_ID = 0
 PADDING_SIDE = 'left'
 
+# legacy constants - used for Llama 2 style prompting
+BOS = '<s>[INST]'
+EOS_USER = '[/INST]'
+EOS = '[/s]'
+
 
 def get_llm_kwargs(required_memory: int, dtype=torch.dtype) -> Dict[str, Any]:
     torch.cuda.empty_cache()
@@ -41,7 +43,7 @@ def get_llm_kwargs(required_memory: int, dtype=torch.dtype) -> Dict[str, Any]:
     }
    kwargs['low_cpu_mem_usage'] = True
     kwargs['device_map'] = 'auto'
-    kwargs['
+    kwargs['dtype'] = dtype
 
     return kwargs
 
@@ -49,50 +51,108 @@ def get_llm_kwargs(required_memory: int, dtype=torch.dtype) -> Dict[str, Any]:
 class LLM(torch.nn.Module):
     r"""A wrapper around a Large Language Model (LLM) from HuggingFace.
 
-
-:
-
-
-
-
-
-
+    Args:
+        model_name (str): The HuggingFace model name
+        num_params (float, optional): An integer representing how many params
+            the HuggingFace model has, in billions. This is used to
+            automatically allocate the correct number of GPUs needed (using a
+            rough heuristic), given the available GPU memory of your GPUs. If
+            not specified, the number of parameters is determined using the
+            `huggingface_hub` module.
+        n_gpus (int, optional): Number of GPUs to use. Designed for advanced
+            users to select how many GPU's they want to set this manually and
+            override the automatic set up mechanism.
+        dtype (torch.dtype, optional): The data type to use for the LLM.
+            (default :obj: `torch.bfloat16`)
+        sys_prompt (str, optional): A system prompt to use for the LLM.
+            (default: :obj: `None`)
     """
     def __init__(
         self,
         model_name: str,
-        num_params:
-
+        num_params: Optional[float] = None,
+        n_gpus: Optional[int] = None,
+        dtype: Optional[torch.dtype] = torch.bfloat16,
+        sys_prompt: Optional[str] = None,
     ) -> None:
         super().__init__()
 
         self.model_name = model_name
 
         from transformers import AutoModelForCausalLM, AutoTokenizer
-
-
-
-
-
+        if n_gpus is None:
+            if num_params is None:
+                from huggingface_hub import get_safetensors_metadata
+                safetensors_metadata = get_safetensors_metadata(model_name)
+                param_count = safetensors_metadata.parameter_count
+                num_params = float(list(param_count.values())[0] // 10**9)
+
+            # A rough heuristic on GPU memory requirements, e.g., we found that
+            # LLAMA3 (8B parameters) fits on a 96GB GPU.
+            required_memory = 96.0 * num_params / 8.0
+            kwargs = get_llm_kwargs(required_memory, dtype)
+        else:
+            gpu_memory: List[int] = []
+            for i in range(n_gpus):
+                gpu_memory.append(torch.cuda.mem_get_info(i)[0] // 1024**3)
+            kwargs = dict(revision='main')
+            kwargs['max_memory'] = {
+                i: f'{memory}GiB'
+                for i, memory in enumerate(gpu_memory)
+            }
+            kwargs['low_cpu_mem_usage'] = True
+            kwargs['device_map'] = 'auto'
+            kwargs['torch_dtype'] = dtype
 
         print(f"Setting up '{model_name}' with configuration: {kwargs}")
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name,
             use_fast=False,
         )
-        self.tokenizer.
-
+        if self.tokenizer.chat_template and self.tokenizer.bos_token is None:
+            dummy_convo = [
+                {
+                    "role": "system",
+                    "content": "dummy"
+                },
+                {
+                    "role": "user",
+                    "content": "convo"
+                },
+            ]
+            text = self.tokenizer.apply_chat_template(
+                dummy_convo,
+                tokenize=True,
+            )
+            self.tokenizer.bos_token = self.tokenizer.decode(text[0])
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = PAD_TOKEN_ID
+        if self.tokenizer.padding_side is None:
+            self.tokenizer.padding_side = PADDING_SIDE
         self.llm = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
         self.word_embedding = self.llm.model.get_input_embeddings()
-
+        if sys_prompt is not None:
+            self.sys_prompt = sys_prompt
+        else:
+            self.sys_prompt = ""
         if 'max_memory' not in kwargs:  # Pure CPU:
-            warnings.warn(
+            warnings.warn(
+                "LLM is being used on CPU, which may be slow. This decision "
+                "was made by a rough hueristic that assumes your GPU set up "
+                "does not have enough GPU RAM. This is done to avoid GPU OOM "
+                "errors. If you think this is a mistake, please initialize "
+                "your LLM with the n_gpus param to dictate how many gpus to "
+                "use for the LLM.", stacklevel=2)
             self.device = torch.device('cpu')
             self.autocast_context = nullcontext()
         else:
             self.device = self.llm.device
-
+            if dtype == torch.float32:
+                self.autocast_context = nullcontext()
+            else:
+                self.autocast_context = torch.amp.autocast('cuda', dtype=dtype)
 
+    # legacy function - used for Llama 2 style prompting
     def _encode_inputs(
         self,
         question: List[str],
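To make the GPU-allocation heuristic in the hunk above concrete (the `required_memory = 96.0 * num_params / 8.0` line): it budgets roughly 12 GB of free GPU memory per billion parameters, then either `get_llm_kwargs()` or the manual `n_gpus` branch builds the `max_memory`/`device_map='auto'` kwargs that spread the weights across GPUs. A quick illustration (my arithmetic, not code from the package):

```python
# Illustration of the heuristic above: ~12 GB of GPU memory per 1B params.
for num_params in (1.0, 8.0, 70.0):  # billions of parameters
    required_memory = 96.0 * num_params / 8.0
    print(f'{num_params:>5.1f}B params -> ~{required_memory:.0f} GB GPU RAM')
# 1B -> ~12 GB, 8B -> ~96 GB (the LLAMA3 data point in the comment above),
# 70B -> ~840 GB, which device_map='auto' then spreads over the GPUs
# reported by torch.cuda.mem_get_info().
```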
@@ -126,6 +186,7 @@ class LLM(torch.nn.Module):
         label_input_ids = label_input_ids + eos_tokens.input_ids
         return label_input_ids
 
+    # legacy function - used for Llama 2 style prompting
     def _input_ids(
         self,
         i: int,
@@ -140,6 +201,7 @@ class LLM(torch.nn.Module):
         input_ids += eos_user_tokens.input_ids
         return input_ids
 
+    # legacy function - used for Llama 2 style prompting
     def _inputs_embeds(
         self,
         i: int,
@@ -199,7 +261,8 @@ class LLM(torch.nn.Module):
                                       device=self.device)
         return inputs_embeds, attention_mask, label_input_ids
 
-
+    # legacy function - used for Llama 2 style prompting
+    def _get_embeds_old(
         self,
         question: List[str],
         context: Optional[List[str]] = None,
@@ -246,6 +309,95 @@ class LLM(torch.nn.Module):
 
         return inputs_embeds, attention_mask, label_input_ids
 
+    def _get_embeds(
+        self,
+        question: List[str],
+        context: Optional[List[str]] = None,
+        embedding: Optional[List[Tensor]] = None,
+        answer: Optional[List[str]] = None,
+    ) -> tuple:
+        if not self.tokenizer.chat_template or not self.sys_prompt:
+            warnings.warn(
+                f"HuggingFace model {self.model_name} is not using a "
+                "chat template, using Llama 2 style prompting. Please "
+                "consider using a more recent model and initialize the "
+                "LLM with `sys_prompt`.", stacklevel=2)
+            return self._get_embeds_old(question, context, embedding, answer)
+        batch_label_input_ids = None
+        if answer is not None:
+            label = self.tokenizer(answer, add_special_tokens=False)
+            eos_tokens = self.tokenizer(self.tokenizer.eos_token,
+                                        add_special_tokens=False)
+            batch_label_input_ids = []
+
+        batch_inputs_embeds = []
+        batch_attention_mask = []
+        for i in range(len(question)):
+            ctx = f"{context[i]} - " if context else ""
+            messages = [
+                {
+                    "role": "system",
+                    "content": self.sys_prompt
+                },
+                {
+                    "role": "user",
+                    "content": f"{ctx} - {question[i]}"
+                },
+            ]
+            text = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                enable_thinking=True,
+            )
+            text = text[len(self.tokenizer.bos_token):]
+            input_ids = self.tokenizer(text,
+                                       add_special_tokens=False).input_ids
+            if answer is not None:
+                label_input_ids = self._label_input_ids(i, label, eos_tokens)
+                input_ids += label_input_ids
+            else:
+                label_input_ids = None
+
+            bos_token = self.tokenizer(
+                self.tokenizer.bos_token,
+                add_special_tokens=False,
+                return_tensors='pt',
+            ).input_ids[0].to(self.device)
+
+            bos_embeds = self.word_embedding(bos_token)
+
+            inputs_embeds = self.word_embedding(
+                torch.tensor(input_ids, device=self.device))
+
+            to_cat = [bos_embeds]
+            if embedding is not None and embedding[i] is not None:
+                to_cat.append(embedding[i])
+            to_cat.append(inputs_embeds)
+            inputs_embeds = torch.cat(to_cat, dim=0).to(self.device)
+
+            (
+                batch_inputs_embeds,
+                batch_attention_mask,
+                batch_label_input_ids,
+            ) = self._append_embeds(
+                inputs_embeds,
+                batch_inputs_embeds,
+                batch_attention_mask,
+                label_input_ids,
+                batch_label_input_ids,
+            )
+
+        pad_token = torch.tensor(self.tokenizer.pad_token_id,
+                                 device=self.device)
+        pad_embeds = self.word_embedding(pad_token).unsqueeze(0)
+
+        inputs_embeds, attention_mask, label_input_ids = self._pad_embeds(
+            pad_embeds, batch_inputs_embeds, batch_attention_mask,
+            batch_label_input_ids)
+
+        return inputs_embeds, attention_mask, label_input_ids
+
     def forward(
         self,
         question: List[str],
@@ -302,17 +454,13 @@ class LLM(torch.nn.Module):
         inputs_embeds, attention_mask, _ = self._get_embeds(
             question, context, embedding)
 
-        bos_token = self.tokenizer(
-            BOS,
-            add_special_tokens=False,
-        ).input_ids[0]
-
         with self.autocast_context:
             outputs = self.llm.generate(
                 inputs_embeds=inputs_embeds,
-                bos_token_id=
+                bos_token_id=self.tokenizer.bos_token_id,
                 max_new_tokens=max_tokens,
                 attention_mask=attention_mask,
+                pad_token_id=self.tokenizer.eos_token_id,
                 use_cache=True,
             )
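Taken together, the `llm.py` hunks switch prompting from the hard-coded Llama 2 markers (kept only as legacy constants and `_get_embeds_old`) to the tokenizer's own chat template plus a user-supplied `sys_prompt`. A usage sketch follows; the model name, prompt strings and the `inference()` entry point (not touched by this diff) are assumptions on my part rather than examples from the package:

```python
# Hypothetical usage of the updated wrapper; the import path follows the
# nn/nlp -> llm/models move recorded in the file listing above.
import torch
from torch_geometric.llm.models import LLM

llm = LLM(
    model_name='meta-llama/Meta-Llama-3.1-8B-Instruct',  # placeholder model
    sys_prompt='Answer the question using the provided graph context.',
    dtype=torch.bfloat16,  # default per the diff
    # n_gpus=2,            # optional: bypass the GPU-memory heuristic
)
answers = llm.inference(
    question=['Which node bridges the two communities?'],
    context=['node 3 is connected to both cluster A and cluster B'],
    max_tokens=128,  # MAX_NEW_TOKENS default after this diff
)
print(answers[0])
```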