pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pyg-nightly might be problematic.

Files changed (228)
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +226 -189
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +179 -31 (import-path change; see the note after this list)
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_trim_to_layer.py +2 -2
  215. torch_geometric/utils/convert.py +17 -10
  216. torch_geometric/utils/cross_entropy.py +34 -13
  217. torch_geometric/utils/embedding.py +91 -2
  218. torch_geometric/utils/geodesic.py +4 -3
  219. torch_geometric/utils/influence.py +279 -0
  220. torch_geometric/utils/map.py +13 -9
  221. torch_geometric/utils/nested.py +1 -1
  222. torch_geometric/utils/smiles.py +3 -3
  223. torch_geometric/utils/sparse.py +7 -14
  224. torch_geometric/visualization/__init__.py +2 -1
  225. torch_geometric/visualization/graph.py +250 -5
  226. torch_geometric/warnings.py +11 -2
  227. torch_geometric/nn/nlp/__init__.py +0 -7
  228. torch_geometric/nn/nlp/sentence_transformer.py +0 -101
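
Note the relocation visible in items 84-102 and 227-228 above: the NLP wrappers that lived under torch_geometric.nn.nlp now sit in the new torch_geometric.llm package. A minimal sketch of the import update downstream code would need, assuming the new torch_geometric.llm.models package re-exports the moved classes (as the added __init__.py files above suggest):

    # Before, against pyg-nightly 2.7.0.dev20241009:
    # from torch_geometric.nn.nlp import LLM, SentenceTransformer

    # After, against pyg-nightly 2.8.0.dev20251207 (assumed re-exports):
    from torch_geometric.llm.models import LLM, SentenceTransformer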
torch_geometric/llm/models/glem.py (new file, item 89 above)
@@ -0,0 +1,397 @@
+from typing import List, Optional, Union
+
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+
+from torch_geometric.loader import DataLoader, NeighborLoader
+from torch_geometric.nn.models import GraphSAGE, basic_gnn
+
+
+def deal_nan(x):
+    # Replace NaN entries with zeros (out-of-place) before combining losses:
+    if isinstance(x, torch.Tensor):
+        x = x.clone()
+        x[torch.isnan(x)] = 0.0
+    return x
+
+
+class GLEM(torch.nn.Module):
+    r"""This GNN+LM co-training model is based on GLEM from the `"Learning on
+    Large-scale Text-attributed Graphs via Variational Inference"
+    <https://arxiv.org/abs/2210.14709>`_ paper.
+
+    Args:
+        lm_to_use (str): A text encoder from the HuggingFace model repo with
+            a classification head. (default: TinyBERT)
+        gnn_to_use (torch_geometric.nn.models): (default: GraphSAGE)
+        out_channels (int): Output channels for the LM and GNN; they must
+            match.
+        num_gnn_heads (int, optional): Number of attention heads, if needed.
+        num_gnn_layers (int): Number of GNN layers.
+        gnn_loss: Loss function for the GNN. (default: CrossEntropyLoss)
+        lm_loss: Loss function for the language model.
+            (default: CrossEntropyLoss)
+        alpha (float): Pseudo-label weight of the E-step (LM optimization).
+            (default: 0.5)
+        beta (float): Pseudo-label weight of the M-step (GNN optimization).
+            (default: 0.5)
+        lm_dtype (torch.dtype): The data type to load the LM with.
+            (default: torch.bfloat16)
+        lm_use_lora (bool): Whether to fine-tune the LM with LoRA (PEFT).
+            (default: True)
+        lora_target_modules: The names of the target modules to apply the
+            LoRA adapter to, e.g., ['q_proj', 'v_proj'] for an LLM.
+            (default: None)
+
+    .. note::
+        See `examples/llm_plus_gnn/glem.py` for example usage.
+    """
+    def __init__(
+        self,
+        lm_to_use: str = 'prajjwal1/bert-tiny',
+        gnn_to_use: basic_gnn = GraphSAGE,
+        out_channels: int = 47,
+        gnn_loss: Optional[nn.Module] = None,
+        lm_loss: Optional[nn.Module] = None,
+        alpha: float = 0.5,
+        beta: float = 0.5,
+        lm_dtype: torch.dtype = torch.bfloat16,
+        lm_use_lora: bool = True,
+        lora_target_modules: Optional[Union[List[str], str]] = None,
+        device: Optional[Union[str, torch.device]] = None,
+    ):
+        super().__init__()
+
+        if gnn_loss is None:
+            gnn_loss = nn.CrossEntropyLoss(reduction='mean')
+        if lm_loss is None:
+            lm_loss = nn.CrossEntropyLoss(reduction='mean')
+        if device is None:
+            device = torch.device('cpu')
+
+        self.device = device
+        self.lm_loss = lm_loss
+        self.gnn = gnn_to_use
+        self.gnn_loss = gnn_loss
+        self.alpha = alpha
+        self.beta = beta
+        from transformers import AutoModelForSequenceClassification
+        self.lm = AutoModelForSequenceClassification.from_pretrained(
+            lm_to_use, num_labels=out_channels, torch_dtype=lm_dtype,
+            offload_folder="offload", trust_remote_code=True)
+        if lm_use_lora:
+            from peft import (
+                LoraConfig,
+                TaskType,
+                get_peft_model,
+                prepare_model_for_kbit_training,
+            )
+            print("Training LM with LORA!")
+            self.lm = prepare_model_for_kbit_training(self.lm)
+            config = LoraConfig(task_type=TaskType.SEQ_CLS, r=16,
+                                lora_alpha=16, lora_dropout=0.05, bias="none",
+                                target_modules=lora_target_modules)
+            self.lm = get_peft_model(self.lm, config)
+            self.lm.print_trainable_parameters()
+        self.lm.config.pad_token_id = self.lm.config.eos_token_id
+        self.lm_device = self.lm.device
+
+        if self.lm.num_labels != self.gnn.out_channels:
+            raise ValueError("The output channels of the language model "
+                             "and the GNN must match.")
+
+    def pre_train_gnn(self, train_loader: NeighborLoader,
+                      optimizer: torch.optim.Optimizer, num_epochs: int,
+                      patience: int, ext_pseudo_labels: torch.Tensor = None,
+                      is_augmented: bool = False, verbose: bool = True):
+        # Pretrain the GNN; optional if you do not have pseudo labels.
+        best_acc = 0
+        early_stopping = 0
+        # Training based on gold data only:
+        for epoch in range(0, num_epochs):
+            acc, loss = self.train_gnn(train_loader, optimizer, epoch,
+                                       ext_pseudo_labels, is_augmented,
+                                       verbose)
+            if acc < best_acc:
+                early_stopping += 1
+                if early_stopping > patience:
+                    print(f'Early stopped by Epoch: {epoch}, '
+                          f'Best acc: {best_acc}')
+                    break
+            best_acc = max(best_acc, acc)
+
+    def pre_train_lm(self, train_loader: DataLoader,
+                     optimizer: torch.optim.Optimizer, num_epochs: int,
+                     patience: int, ext_pseudo_labels: torch.Tensor = None,
+                     is_augmented: bool = False, verbose: bool = True):
+        # Pretrain the language model:
+        best_acc = 0
+        early_stopping = 0
+        for epoch in range(1, num_epochs + 1):
+            acc, loss = self.train_lm(train_loader, optimizer, epoch,
+                                      ext_pseudo_labels, is_augmented,
+                                      verbose)
+            if acc < best_acc:
+                early_stopping += 1
+                if early_stopping > patience:
+                    print(f'Early stopped by Epoch: {epoch}, '
+                          f'Best acc: {best_acc}')
+                    break
+            best_acc = max(best_acc, acc)
+
+    def train(self, em_phase: str,
+              train_loader: Union[DataLoader, NeighborLoader],
+              optimizer: torch.optim.Optimizer, pseudo_labels: torch.Tensor,
+              epoch: int, is_augmented: bool = False, verbose: bool = False):
+        r"""GLEM training step (EM steps).
+
+        Args:
+            em_phase (str): 'gnn' or 'lm', the phase to train on.
+            train_loader (Union[DataLoader, NeighborLoader]): Use a
+                DataLoader for LM training (tokenized data, labels, and an
+                is_gold mask) and a NeighborLoader for GNN training (x and
+                edge_index).
+            optimizer (torch.optim.Optimizer): Optimizer for training.
+            pseudo_labels (torch.Tensor): The predicted labels used as
+                pseudo labels.
+            epoch (int): Current epoch.
+            is_augmented (bool): Whether to use pseudo labels.
+            verbose (bool): Whether to print a training progress bar.
+
+        Returns:
+            acc (float): Training accuracy.
+            loss (float): Loss value.
+        """
+        if pseudo_labels is not None:
+            pseudo_labels = pseudo_labels.to(self.device)
+        if em_phase == 'gnn':
+            acc, loss = self.train_gnn(train_loader, optimizer, epoch,
+                                       pseudo_labels, is_augmented, verbose)
+        if em_phase == 'lm':
+            acc, loss = self.train_lm(train_loader, optimizer, epoch,
+                                      pseudo_labels, is_augmented, verbose)
+        return acc, loss
+
+    def train_lm(self, train_loader: DataLoader,
+                 optimizer: torch.optim.Optimizer, epoch: int,
+                 pseudo_labels: torch.Tensor = None,
+                 is_augmented: bool = False, verbose: bool = True):
+        r"""Language model training in every epoch.
+
+        Args:
+            train_loader (loader.dataloader.DataLoader): Text token loader.
+            optimizer (torch.optim.Optimizer): Model optimizer.
+            epoch (int): Current training epoch.
+            pseudo_labels (torch.Tensor): 1-D tensor of GNN predictions.
+            is_augmented (bool): Whether to train with pseudo labels.
+            verbose (bool): Whether to print a training progress bar.
+
+        Returns:
+            approx_acc (torch.Tensor): Training accuracy.
+            loss (torch.float): Loss value.
+        """
+        all_out = []
+        total_loss = total_correct = 0
+        num_nodes = train_loader.dataset.indices.size(0)
+        self.lm.train()
+        if verbose:
+            pbar = tqdm(total=num_nodes)
+            pbar.set_description(f'Epoch {epoch:02d}')
+        for batch in train_loader:
+            inputs = {k: v.to(self.device) for k, v in batch['input'].items()}
+            out = self.lm(**inputs).logits
+            labels = batch['labels'].to(self.device).squeeze()
+            # Train with pseudo labels or not:
+            if is_augmented:
+                pl_batch = pseudo_labels[batch['n_id']].to(self.device)
+            else:
+                pl_batch = None
+            loss = self.loss(out, labels, self.lm_loss,
+                             batch['is_gold'].to(self.device), pl_batch,
+                             self.alpha, is_augmented)
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+            all_out.append(out)
+            total_correct += int(out.argmax(dim=-1).eq(labels).sum())
+            total_loss += float(loss.detach())
+            if verbose:
+                pbar.update(batch['n_id'].size(0))
+
+        all_out = torch.cat(all_out, dim=0)
+        approx_acc = total_correct / num_nodes
+        loss = total_loss / len(train_loader)
+        if verbose:
+            pbar.close()
+            print(f'Epoch {epoch:02d} Loss: {loss:.4f} '
+                  f'Approx. Train: {approx_acc:.4f}')
+        return approx_acc, loss
+
+    def train_gnn(self, train_loader: NeighborLoader,
+                  optimizer: torch.optim.Optimizer, epoch: int,
+                  pseudo_labels: torch.Tensor = None,
+                  is_augmented: bool = False, verbose: bool = True):
+        r"""GNN training step in every epoch.
+
+        Args:
+            train_loader (loader.NeighborLoader): GNN neighbor node loader.
+            optimizer (torch.optim.Optimizer): Model optimizer.
+            epoch (int): Current training epoch.
+            pseudo_labels (torch.Tensor): 1-D tensor of LM predictions.
+            is_augmented (bool): Whether to use pseudo-labeled nodes.
+            verbose (bool): Whether to print training progress.
+
+        Returns:
+            approx_acc (torch.Tensor): Training accuracy.
+            loss (torch.float): Loss value.
+        """
+        self.gnn.train()
+        num_nodes = train_loader.input_nodes.size(0)
+        if verbose:
+            pbar = tqdm(total=num_nodes)
+            pbar.set_description(f'Epoch {epoch:02d}')
+        total_loss = total_correct = 0
+        all_out = []
+        for batch in train_loader:
+            batch = batch.to(self.device)
+            out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size]
+            all_out.append(out)
+            labels = batch.y[:batch.batch_size].squeeze()
+            is_gold_batch = batch.is_gold[:batch.batch_size].squeeze()
+            # Train with pseudo labels or not:
+            if is_augmented and pseudo_labels is not None:
+                pl_batch = pseudo_labels[batch.n_id[:batch.batch_size]]
+            else:
+                pl_batch = None
+            loss = self.loss(out, labels, self.gnn_loss, is_gold_batch,
+                             pl_batch, self.beta, is_augmented)
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+            total_loss += float(loss.detach())
+            total_correct += int(out.argmax(dim=-1).eq(labels).sum())
+            if verbose:
+                pbar.update(batch.batch_size)
+
+        all_out = torch.cat(all_out, dim=0)
+        loss = total_loss / len(train_loader)
+        approx_acc = total_correct / num_nodes
+        if verbose:
+            pbar.close()
+            print(f'Epoch: {epoch:02d} Loss: {loss:.4f} '
+                  f'Approx. Train: {approx_acc:.4f}')
+        return approx_acc, loss
+
+    @torch.no_grad()
+    def inference(self, em_phase: str,
+                  data_loader: Union[NeighborLoader, DataLoader],
+                  verbose: bool = False):
+        r"""GLEM inference step.
+
+        Args:
+            em_phase (str): 'gnn' or 'lm'.
+            data_loader (DataLoader or NeighborLoader): A DataLoader with
+                tokenized data for LM inference, or a NeighborLoader with x
+                and edge_index for GNN inference.
+            verbose (bool): Whether to print inference progress.
+
+        Returns:
+            out (torch.Tensor): An n * m tensor, where m is the number of
+                classes and n is the number of nodes.
+        """
+        out = None
+        if em_phase == 'gnn':
+            self.gnn.eval()
+            out = self.inference_gnn(data_loader, verbose)
+        elif em_phase == 'lm':
+            self.lm.eval()
+            out = self.inference_lm(data_loader, verbose)
+        return out
+
+    @torch.no_grad()
+    def inference_lm(self, data_loader: DataLoader, verbose: bool = True):
+        r"""LM inference step.
+
+        Args:
+            data_loader (DataLoader): Includes tokens, labels, and the gold
+                mask.
+            verbose (bool): Whether to print a progress bar.
+
+        Returns:
+            preds (torch.Tensor): Predictions from the LM; convert to pseudo
+                labels via preds.argmax(dim=-1).unsqueeze(1).
+        """
+        if verbose:
+            pbar = tqdm(total=data_loader.dataset._data.num_nodes)
+            pbar.set_description('LM inference stage')
+        self.lm.eval()
+        preds = []
+        for batch in data_loader:
+            inputs = {k: v.to(self.device) for k, v in batch['input'].items()}
+            logits = self.lm(**inputs).logits
+            preds.append(logits)
+            if verbose:
+                pbar.update(batch['n_id'].size(0))
+        if verbose:
+            pbar.close()
+        preds = torch.cat(preds)
+        return preds
+
+    @torch.no_grad()
+    def inference_gnn(self, data_loader: NeighborLoader, verbose: bool = True):
+        r"""GNN inference step.
+
+        Args:
+            data_loader (NeighborLoader): Includes x and edge_index.
+            verbose (bool): Whether to print a progress bar.
+
+        Returns:
+            preds (torch.Tensor): Predictions from the GNN; convert to
+                pseudo labels via preds.argmax(dim=-1).unsqueeze(1).
+        """
+        if verbose:
+            pbar = tqdm(total=data_loader.data.num_nodes)
+            pbar.set_description('GNN inference stage')
+        preds = []
+        self.gnn.eval()
+        for batch in data_loader:
+            batch = batch.to(self.device)
+            out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size]
+            preds.append(out)
+            if verbose:
+                pbar.update(batch.batch_size)
+        if verbose:
+            pbar.close()
+        preds = torch.cat(preds, dim=0)
+        return preds
+
+    def loss(self, logits: torch.Tensor, labels: torch.Tensor,
+             loss_func: torch.nn.functional, is_gold: torch.Tensor,
+             pseudo_labels: torch.Tensor = None, pl_weight: float = 0.5,
+             is_augmented: bool = True):
+        r"""Core function of variational EM inference. It combines the loss
+        on gold (original training) labels with the loss on pseudo labels.
+
+        Reference:
+        <https://github.com/AndyJZhao/GLEM/blob/main/src/models/GLEM/GLEM_utils.py>  # noqa
+
+        Args:
+            logits (torch.Tensor): Predictions from the LM or GNN.
+            labels (torch.Tensor): Node labels combined from ground truth
+                and pseudo labels (if provided).
+            loss_func (torch.nn.modules.loss): Loss function for
+                classification.
+            is_gold (torch.Tensor): A boolean tensor that masks ground-truth
+                labels during training; hence ~is_gold masks pseudo labels.
+            pseudo_labels (torch.Tensor): Predictions from the other model.
+            pl_weight: The pseudo-label weight used in the E-step and M-step
+                optimization (alpha in the E-step, beta in the M-step).
+            is_augmented: Whether to use EM, or to train the GNN and LM on
+                gold data only.
+        """
+        if is_augmented and (sum(~is_gold) > 0):
+            mle_loss = deal_nan(loss_func(logits[is_gold], labels[is_gold]))
+            # All labels other than the ground truth (gold labels):
+            pseudo_label_loss = deal_nan(
+                loss_func(logits[~is_gold], pseudo_labels[~is_gold]))
+            loss = pl_weight * pseudo_label_loss + (1 - pl_weight) * mle_loss
+        else:
+            loss = loss_func(logits, labels)
+        return loss
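
For orientation, here is a minimal co-training sketch of the EM loop that GLEM exposes. This is not part of the package: lm_loader and gnn_loader are assumed to be pre-built (a DataLoader yielding dicts with input, labels, n_id, and is_gold keys, and a NeighborLoader over nodes carrying an is_gold mask), and every hyperparameter is a placeholder:

    import torch

    from torch_geometric.llm.models import GLEM  # new module path, per this diff
    from torch_geometric.nn.models import GraphSAGE

    # GLEM expects an instantiated GNN whose out_channels matches the LM head:
    gnn = GraphSAGE(in_channels=128, hidden_channels=256, num_layers=2,
                    out_channels=47)
    model = GLEM(lm_to_use='prajjwal1/bert-tiny', gnn_to_use=gnn,
                 out_channels=47, device='cpu')
    lm_opt = torch.optim.Adam(model.lm.parameters(), lr=1e-4)
    gnn_opt = torch.optim.Adam(model.gnn.parameters(), lr=1e-3)

    for em_round in range(3):
        # E-step: fit the LM to pseudo labels produced by the GNN.
        pseudo = model.inference('gnn', gnn_loader).argmax(dim=-1).unsqueeze(1)
        model.train('lm', lm_loader, lm_opt, pseudo, em_round,
                    is_augmented=True)
        # M-step: fit the GNN to pseudo labels produced by the LM.
        pseudo = model.inference('lm', lm_loader).argmax(dim=-1).unsqueeze(1)
        model.train('gnn', gnn_loader, gnn_opt, pseudo, em_round,
                    is_augmented=True)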
torch_geometric/{nn/nlp → llm/models}/llm.py (item 90 above)
@@ -10,15 +10,17 @@ try:
 except ImportError:
     BatchEncoding = Dict
 
-BOS = '<s>[INST]'
-EOS_USER = '[/INST]'
-EOS = '[/s]'
 IGNORE_INDEX = -100
 MAX_TXT_LEN = 512
-MAX_NEW_TOKENS = 32
+MAX_NEW_TOKENS = 128
 PAD_TOKEN_ID = 0
 PADDING_SIDE = 'left'
 
+# legacy constants - used for Llama 2 style prompting
+BOS = '<s>[INST]'
+EOS_USER = '[/INST]'
+EOS = '[/s]'
+
 
 def get_llm_kwargs(required_memory: int, dtype=torch.dtype) -> Dict[str, Any]:
     torch.cuda.empty_cache()
@@ -49,50 +51,108 @@ def get_llm_kwargs(required_memory: int, dtype=torch.dtype) -> Dict[str, Any]:
 class LLM(torch.nn.Module):
     r"""A wrapper around a Large Language Model (LLM) from HuggingFace.
 
-    model_name (str): The HuggingFace model name, *e.g.*, :obj:`"llama2"` or
-        :obj:`"gemma"`.
-    num_params (int): An integer representing how many parameters the
-        HuggingFace model has, in billions. This is used to automatically
-        allocate the correct number of GPUs needed, given the available GPU
-        memory of your GPUs.
-    dtype (torch.dtype, optional): The data type to use for the LLM.
-        (default :obj: `torch.bloat16`)
+    Args:
+        model_name (str): The HuggingFace model name.
+        num_params (float, optional): The number of parameters the
+            HuggingFace model has, in billions. This is used to
+            automatically allocate the correct number of GPUs needed (using
+            a rough heuristic), given the available GPU memory of your
+            GPUs. If not specified, the number of parameters is determined
+            using the `huggingface_hub` module.
+        n_gpus (int, optional): The number of GPUs to use. Lets advanced
+            users set the GPU count manually, overriding the automatic
+            setup mechanism.
+        dtype (torch.dtype, optional): The data type to use for the LLM.
+            (default: :obj:`torch.bfloat16`)
+        sys_prompt (str, optional): A system prompt to use for the LLM.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
         model_name: str,
-        num_params: int,
-        dtype=torch.bfloat16,
+        num_params: Optional[float] = None,
+        n_gpus: Optional[int] = None,
+        dtype: Optional[torch.dtype] = torch.bfloat16,
+        sys_prompt: Optional[str] = None,
     ) -> None:
        super().__init__()
 
         self.model_name = model_name
 
         from transformers import AutoModelForCausalLM, AutoTokenizer
-
-        # A rough heuristic on GPU memory requirements, e.g., we found that
-        # LLAMA2 (7B parameters) fits on a 85GB GPU.
-        required_memory = 85 * num_params / 7
-        kwargs = get_llm_kwargs(required_memory, dtype)
+        if n_gpus is None:
+            if num_params is None:
+                from huggingface_hub import get_safetensors_metadata
+                safetensors_metadata = get_safetensors_metadata(model_name)
+                param_count = safetensors_metadata.parameter_count
+                num_params = float(list(param_count.values())[0] // 10**9)
+
+            # A rough heuristic on GPU memory requirements, e.g., we found
+            # that LLAMA3 (8B parameters) fits on a 96GB GPU.
+            required_memory = 96.0 * num_params / 8.0
+            kwargs = get_llm_kwargs(required_memory, dtype)
+        else:
+            gpu_memory: List[int] = []
+            for i in range(n_gpus):
+                gpu_memory.append(torch.cuda.mem_get_info(i)[0] // 1024**3)
+            kwargs = dict(revision='main')
+            kwargs['max_memory'] = {
+                i: f'{memory}GiB'
+                for i, memory in enumerate(gpu_memory)
+            }
+            kwargs['low_cpu_mem_usage'] = True
+            kwargs['device_map'] = 'auto'
+            kwargs['torch_dtype'] = dtype
 
         print(f"Setting up '{model_name}' with configuration: {kwargs}")
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name,
             use_fast=False,
         )
-        self.tokenizer.pad_token_id = PAD_TOKEN_ID
-        self.tokenizer.padding_side = PADDING_SIDE
+        if self.tokenizer.chat_template and self.tokenizer.bos_token is None:
+            dummy_convo = [
+                {
+                    "role": "system",
+                    "content": "dummy"
+                },
+                {
+                    "role": "user",
+                    "content": "convo"
+                },
+            ]
+            text = self.tokenizer.apply_chat_template(
+                dummy_convo,
+                tokenize=True,
+            )
+            self.tokenizer.bos_token = self.tokenizer.decode(text[0])
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = PAD_TOKEN_ID
+        if self.tokenizer.padding_side is None:
+            self.tokenizer.padding_side = PADDING_SIDE
         self.llm = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
         self.word_embedding = self.llm.model.get_input_embeddings()
-
+        if sys_prompt is not None:
+            self.sys_prompt = sys_prompt
+        else:
+            self.sys_prompt = ""
         if 'max_memory' not in kwargs:  # Pure CPU:
-            warnings.warn("LLM is being used on CPU, which may be slow")
+            warnings.warn(
+                "LLM is being used on CPU, which may be slow. This decision "
+                "was made by a rough heuristic that assumes your GPU setup "
+                "does not have enough GPU RAM. This is done to avoid GPU OOM "
+                "errors. If you think this is a mistake, please initialize "
+                "your LLM with the n_gpus param to dictate how many GPUs to "
+                "use for the LLM.", stacklevel=2)
             self.device = torch.device('cpu')
             self.autocast_context = nullcontext()
         else:
             self.device = self.llm.device
-            self.autocast_context = torch.amp.autocast('cuda', dtype=dtype)
+            if dtype == torch.float32:
+                self.autocast_context = nullcontext()
+            else:
+                self.autocast_context = torch.amp.autocast('cuda', dtype=dtype)
 
+    # legacy function - used for Llama 2 style prompting
     def _encode_inputs(
         self,
         question: List[str],
@@ -126,6 +186,7 @@ class LLM(torch.nn.Module):
         label_input_ids = label_input_ids + eos_tokens.input_ids
         return label_input_ids
 
+    # legacy function - used for Llama 2 style prompting
     def _input_ids(
         self,
         i: int,
@@ -140,6 +201,7 @@ class LLM(torch.nn.Module):
         input_ids += eos_user_tokens.input_ids
         return input_ids
 
+    # legacy function - used for Llama 2 style prompting
     def _inputs_embeds(
         self,
         i: int,
@@ -199,7 +261,8 @@ class LLM(torch.nn.Module):
                                        device=self.device)
         return inputs_embeds, attention_mask, label_input_ids
 
-    def _get_embeds(
+    # legacy function - used for Llama 2 style prompting
+    def _get_embeds_old(
         self,
         question: List[str],
         context: Optional[List[str]] = None,
@@ -246,6 +309,95 @@ class LLM(torch.nn.Module):
 
         return inputs_embeds, attention_mask, label_input_ids
 
+    def _get_embeds(
+        self,
+        question: List[str],
+        context: Optional[List[str]] = None,
+        embedding: Optional[List[Tensor]] = None,
+        answer: Optional[List[str]] = None,
+    ) -> tuple:
+        if not self.tokenizer.chat_template or not self.sys_prompt:
+            warnings.warn(
+                f"HuggingFace model {self.model_name} is not using a "
+                "chat template, using Llama 2 style prompting. Please "
+                "consider using a more recent model and initialize the "
+                "LLM with `sys_prompt`.", stacklevel=2)
+            return self._get_embeds_old(question, context, embedding, answer)
+        batch_label_input_ids = None
+        if answer is not None:
+            label = self.tokenizer(answer, add_special_tokens=False)
+            eos_tokens = self.tokenizer(self.tokenizer.eos_token,
+                                        add_special_tokens=False)
+            batch_label_input_ids = []
+
+        batch_inputs_embeds = []
+        batch_attention_mask = []
+        for i in range(len(question)):
+            ctx = f"{context[i]} - " if context else ""
+            messages = [
+                {
+                    "role": "system",
+                    "content": self.sys_prompt
+                },
+                {
+                    "role": "user",
+                    "content": f"{ctx}{question[i]}"
+                },
+            ]
+            text = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                enable_thinking=True,
+            )
+            text = text[len(self.tokenizer.bos_token):]
+            input_ids = self.tokenizer(text,
+                                       add_special_tokens=False).input_ids
+            if answer is not None:
+                label_input_ids = self._label_input_ids(i, label, eos_tokens)
+                input_ids += label_input_ids
+            else:
+                label_input_ids = None
+
+            bos_token = self.tokenizer(
+                self.tokenizer.bos_token,
+                add_special_tokens=False,
+                return_tensors='pt',
+            ).input_ids[0].to(self.device)
+
+            bos_embeds = self.word_embedding(bos_token)
+
+            inputs_embeds = self.word_embedding(
+                torch.tensor(input_ids, device=self.device))
+
+            to_cat = [bos_embeds]
+            if embedding is not None and embedding[i] is not None:
+                to_cat.append(embedding[i])
+            to_cat.append(inputs_embeds)
+            inputs_embeds = torch.cat(to_cat, dim=0).to(self.device)
+
+            (
+                batch_inputs_embeds,
+                batch_attention_mask,
+                batch_label_input_ids,
+            ) = self._append_embeds(
+                inputs_embeds,
+                batch_inputs_embeds,
+                batch_attention_mask,
+                label_input_ids,
+                batch_label_input_ids,
+            )
+
+        pad_token = torch.tensor(self.tokenizer.pad_token_id,
+                                 device=self.device)
+        pad_embeds = self.word_embedding(pad_token).unsqueeze(0)
+
+        inputs_embeds, attention_mask, label_input_ids = self._pad_embeds(
+            pad_embeds, batch_inputs_embeds, batch_attention_mask,
+            batch_label_input_ids)
+
+        return inputs_embeds, attention_mask, label_input_ids
+
     def forward(
         self,
         question: List[str],
@@ -302,17 +454,13 @@ class LLM(torch.nn.Module):
         inputs_embeds, attention_mask, _ = self._get_embeds(
             question, context, embedding)
 
-        bos_token = self.tokenizer(
-            BOS,
-            add_special_tokens=False,
-        ).input_ids[0]
-
         with self.autocast_context:
             outputs = self.llm.generate(
                 inputs_embeds=inputs_embeds,
-                bos_token_id=bos_token,
+                bos_token_id=self.tokenizer.bos_token_id,
                 max_new_tokens=max_tokens,
                 attention_mask=attention_mask,
+                pad_token_id=self.tokenizer.eos_token_id,
                 use_cache=True,
             )
 
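
For orientation, a minimal usage sketch of the reworked constructor: the model name and prompt are placeholders, num_params is left unset so the parameter count is read from huggingface_hub safetensors metadata as in the hunk above, and the inference() call assumes this class's existing generation entry point:

    from torch_geometric.llm.models import LLM  # moved from torch_geometric.nn.nlp

    llm = LLM(
        model_name='meta-llama/Meta-Llama-3.1-8B-Instruct',  # placeholder
        n_gpus=1,  # optional: set manually to bypass the GPU-memory heuristic
        sys_prompt='Answer in one sentence.',  # enables chat-template prompting
    )
    out = llm.inference(question=['What is a text-attributed graph?'])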