pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of pyg-nightly might be problematic.
Files changed (268)
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
torch_geometric/llm/models/glem.py (new file)
@@ -0,0 +1,397 @@
+ from typing import List, Optional, Union
+
+ import torch
+ import torch.nn as nn
+ from tqdm import tqdm
+
+ from torch_geometric.loader import DataLoader, NeighborLoader
+ from torch_geometric.nn.models import GraphSAGE, basic_gnn
+
+
+ def deal_nan(x):
+     if isinstance(x, torch.Tensor):
+         x = x.clone()
+         x[torch.isnan(x)] = 0.0
+     return x
+
+
+ class GLEM(torch.nn.Module):
+     r"""This GNN+LM co-training model is based on GLEM from the `"Learning on
+     Large-scale Text-attributed Graphs via Variational Inference"
+     <https://arxiv.org/abs/2210.14709>`_ paper.
+
+     Args:
+         lm_to_use (str): A text encoder from the Hugging Face model hub with
+             a classification head. (default: TinyBERT)
+         gnn_to_use (torch_geometric.nn.models): The GNN to co-train.
+             (default: GraphSAGE)
+         out_channels (int): Output channels of the LM and GNN; must match.
+         num_gnn_heads (int, optional): Number of attention heads, if needed.
+         num_gnn_layers (int): Number of GNN layers.
+         gnn_loss: Loss function for the GNN. (default: CrossEntropyLoss)
+         lm_loss: Loss function for the language model.
+             (default: CrossEntropyLoss)
+         alpha (float): Pseudo-label weight in the E-step (LM optimization).
+             (default: 0.5)
+         beta (float): Pseudo-label weight in the M-step (GNN optimization).
+             (default: 0.5)
+         lm_dtype (torch.dtype): Data type the LM is loaded with.
+             (default: torch.bfloat16)
+         lm_use_lora (bool): Whether to fine-tune the LM with LoRA (PEFT).
+             (default: True)
+         lora_target_modules: Names of the target modules to apply the LoRA
+             adapter to, e.g. ['q_proj', 'v_proj'] for an LLM. (default: None)
+
+     .. note::
+         See `examples/llm_plus_gnn/glem.py` for example usage.
+     """
+     def __init__(
+         self,
+         lm_to_use: str = 'prajjwal1/bert-tiny',
+         gnn_to_use: basic_gnn = GraphSAGE,
+         out_channels: int = 47,
+         gnn_loss: Optional[nn.Module] = None,
+         lm_loss: Optional[nn.Module] = None,
+         alpha: float = 0.5,
+         beta: float = 0.5,
+         lm_dtype: torch.dtype = torch.bfloat16,
+         lm_use_lora: bool = True,
+         lora_target_modules: Optional[Union[List[str], str]] = None,
+         device: Optional[Union[str, torch.device]] = None,
+     ):
+         super().__init__()
+
+         if gnn_loss is None:
+             gnn_loss = nn.CrossEntropyLoss(reduction='mean')
+         if lm_loss is None:
+             lm_loss = nn.CrossEntropyLoss(reduction='mean')
+         if device is None:
+             device = torch.device('cpu')
+
+         self.device = device
+         self.lm_loss = lm_loss
+         self.gnn = gnn_to_use
+         self.gnn_loss = gnn_loss
+         self.alpha = alpha
+         self.beta = beta
+         self.lm = lm_to_use
+         from transformers import AutoModelForSequenceClassification
+         self.lm = AutoModelForSequenceClassification.from_pretrained(
+             lm_to_use, num_labels=out_channels, torch_dtype=lm_dtype,
+             offload_folder="offload", trust_remote_code=True)
+         if lm_use_lora:
+             from peft import (
+                 LoraConfig,
+                 TaskType,
+                 get_peft_model,
+                 prepare_model_for_kbit_training,
+             )
+             print("Training LM with LORA!")
+             self.lm = prepare_model_for_kbit_training(self.lm)
+             config = LoraConfig(task_type=TaskType.SEQ_CLS, r=16,
+                                 lora_alpha=16, lora_dropout=0.05, bias="none",
+                                 target_modules=lora_target_modules)
+             self.lm = get_peft_model(self.lm, config)
+             self.lm.print_trainable_parameters()
+         self.lm.config.pad_token_id = self.lm.config.eos_token_id
+         self.lm_device = self.lm.device
+
+         if self.lm.num_labels != self.gnn.out_channels:
+             raise ValueError('''The output channel of language model \
+                 and gnn should be the same''')
+
+     def pre_train_gnn(self, train_loader: NeighborLoader,
+                       optimizer: torch.optim.Optimizer, num_epochs: int,
+                       patience: int, ext_pseudo_labels: torch.Tensor = None,
+                       is_augmented: bool = False, verbose: bool = True):
+         # Pretrain the GNN; an optional step if you do not have pseudo labels.
+         best_acc = 0
+         early_stopping = 0
+         # Training only on gold data:
+         for epoch in range(0, num_epochs):
+             acc, loss = self.train_gnn(train_loader, optimizer, epoch,
+                                        ext_pseudo_labels, is_augmented,
+                                        verbose)
+             if acc < best_acc:
+                 early_stopping += 1
+                 if early_stopping > patience:
+                     print(f'Early stopped by Epoch: {epoch}, '
+                           f'Best acc: {best_acc}')
+                     break
+             best_acc = max(best_acc, acc)
+
+     def pre_train_lm(self, train_loader: DataLoader,
+                      optimizer: torch.optim.Optimizer, num_epochs: int,
+                      patience: int, ext_pseudo_labels: torch.Tensor = None,
+                      is_augmented: bool = False, verbose: bool = True):
+         # Pretrain the language model:
+         best_acc = 0
+         early_stopping = 0
+         for epoch in range(1, num_epochs + 1):
+             acc, loss = self.train_lm(train_loader, optimizer, epoch,
+                                       ext_pseudo_labels, is_augmented, verbose)
+             if acc < best_acc:
+                 early_stopping += 1
+                 if early_stopping > patience:
+                     print(f'Early stopped by Epoch: {epoch}, '
+                           f'Best acc: {best_acc}')
+                     break
+             best_acc = max(best_acc, acc)
+
+     def train(self, em_phase: str,
+               train_loader: Union[DataLoader, NeighborLoader],
+               optimizer: torch.optim.Optimizer, pseudo_labels: torch.Tensor,
+               epoch: int, is_augmented: bool = False, verbose: bool = False):
+         r"""One GLEM training step (E-step or M-step).
+
+         Args:
+             em_phase (str): 'gnn' or 'lm'; the phase to train in.
+             train_loader (Union[DataLoader, NeighborLoader]): Use a
+                 DataLoader for LM training (tokenized data, labels, is_gold
+                 mask); use a NeighborLoader for GNN training (x, edge_index).
+             optimizer (torch.optim.Optimizer): Optimizer for training.
+             pseudo_labels (torch.Tensor): Predicted labels used as pseudo
+                 labels.
+             epoch (int): Current epoch.
+             is_augmented (bool): Whether to train with pseudo labels.
+             verbose (bool): Whether to print a training progress bar.
+
+         Returns:
+             acc (float): Training accuracy.
+             loss (float): Loss value.
+         """
+         if pseudo_labels is not None:
+             pseudo_labels = pseudo_labels.to(self.device)
+         if em_phase == 'gnn':
+             acc, loss = self.train_gnn(train_loader, optimizer, epoch,
+                                        pseudo_labels, is_augmented, verbose)
+         if em_phase == 'lm':
+             acc, loss = self.train_lm(train_loader, optimizer, epoch,
+                                       pseudo_labels, is_augmented, verbose)
+         return acc, loss
+
+     def train_lm(self, train_loader: DataLoader,
+                  optimizer: torch.optim.Optimizer, epoch: int,
+                  pseudo_labels: torch.Tensor = None,
+                  is_augmented: bool = False, verbose: bool = True):
+         r"""Language model training for one epoch.
+
+         Args:
+             train_loader (DataLoader): Text token dataloader.
+             optimizer (torch.optim.Optimizer): Model optimizer.
+             epoch (int): Current training epoch.
+             pseudo_labels (torch.Tensor): 1-D tensor of predictions from the
+                 GNN.
+             is_augmented (bool): Whether to train with pseudo labels.
+             verbose (bool): Whether to print a training progress bar.
+
+         Returns:
+             approx_acc (float): Training accuracy.
+             loss (float): Loss value.
+         """
+         all_out = []
+         total_loss = total_correct = 0
+         num_nodes = train_loader.dataset.indices.size(0)
+         self.lm.train()
+         if verbose:
+             pbar = tqdm(total=num_nodes)
+             pbar.set_description(f'Epoch {epoch:02d}')
+         for batch in train_loader:
+             inputs = {k: v.to(self.device) for k, v in batch['input'].items()}
+             out = self.lm(**inputs).logits
+             labels = batch['labels'].to(self.device).squeeze()
+             # Train with or without pseudo labels:
+             if is_augmented:
+                 pl_batch = pseudo_labels[batch['n_id']].to(self.device)
+             else:
+                 pl_batch = None
+             loss = self.loss(out, labels, self.lm_loss,
+                              batch['is_gold'].to(self.device), pl_batch,
+                              self.alpha, is_augmented)
+             loss.backward()
+             optimizer.step()
+             optimizer.zero_grad()
+             all_out.append(out)
+             total_correct += int(out.argmax(dim=-1).eq(labels).sum())
+             total_loss += float(loss.detach())
+             if verbose:
+                 pbar.update(batch['n_id'].size(0))
+
+         all_out = torch.cat(all_out, dim=0)
+         approx_acc = total_correct / num_nodes
+         loss = total_loss / len(train_loader)
+         if verbose:
+             pbar.close()
+             print(f'Epoch {epoch:02d} Loss: {loss:.4f} '
+                   f'Approx. Train: {approx_acc:.4f}')
+         return approx_acc, loss
+
+     def train_gnn(self, train_loader: NeighborLoader,
+                   optimizer: torch.optim.Optimizer, epoch: int,
+                   pseudo_labels: torch.Tensor = None,
+                   is_augmented: bool = False, verbose: bool = True):
+         r"""GNN training for one epoch.
+
+         Args:
+             train_loader (NeighborLoader): GNN neighbor node loader.
+             optimizer (torch.optim.Optimizer): Model optimizer.
+             epoch (int): Current training epoch.
+             pseudo_labels (torch.Tensor): 1-D tensor of predictions from the
+                 LM.
+             is_augmented (bool): Whether to use pseudo-labeled nodes.
+             verbose (bool): Whether to print training progress.
+
+         Returns:
+             approx_acc (float): Training accuracy.
+             loss (float): Loss value.
+         """
+         self.gnn.train()
+         num_nodes = train_loader.input_nodes.size(0)
+         if verbose:
+             pbar = tqdm(total=num_nodes)
+             pbar.set_description(f'Epoch {epoch:02d}')
+         total_loss = total_correct = 0
+         all_out = []
+         for batch in train_loader:
+             batch = batch.to(self.device)
+             out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size]
+             all_out.append(out)
+             labels = batch.y[:batch.batch_size].squeeze()
+             is_gold_batch = batch.is_gold[:batch.batch_size].squeeze()
+             # Train with or without pseudo labels:
+             if is_augmented and pseudo_labels is not None:
+                 pl_batch = pseudo_labels[batch.n_id[:batch.batch_size]]
+             else:
+                 pl_batch = None
+             loss = self.loss(out, labels, self.gnn_loss, is_gold_batch,
+                              pl_batch, self.beta, is_augmented)
+             loss.backward()
+             optimizer.step()
+             optimizer.zero_grad()
+             total_loss += float(loss.detach())
+             total_correct += int(out.argmax(dim=-1).eq(labels).sum())
+             if verbose:
+                 pbar.update(batch.batch_size)
+
+         all_out = torch.cat(all_out, dim=0)
+         loss = total_loss / len(train_loader)
+         approx_acc = total_correct / num_nodes
+         if verbose:
+             pbar.close()
+             print(f'Epoch: {epoch:02d} Loss: {loss:.4f} '
+                   f'Approx. Train: {approx_acc:.4f}')
+         return approx_acc, loss
+
+     @torch.no_grad()
+     def inference(self, em_phase: str,
+                   data_loader: Union[NeighborLoader, DataLoader],
+                   verbose: bool = False):
+         r"""GLEM inference step.
+
+         Args:
+             em_phase (str): 'gnn' or 'lm'.
+             data_loader (Union[NeighborLoader, DataLoader]): A DataLoader for
+                 LM inference (tokenized data); a NeighborLoader for GNN
+                 inference (x, edge_index).
+             verbose (bool): Whether to print inference progress.
+
+         Returns:
+             out (torch.Tensor): An [n, m] tensor, where m is the number of
+                 classes and n is the number of nodes.
+         """
+         out = None
+         if em_phase == 'gnn':
+             self.gnn.eval()
+             out = self.inference_gnn(data_loader, verbose)
+         elif em_phase == 'lm':
+             self.lm.eval()
+             out = self.inference_lm(data_loader, verbose)
+         return out
+
+     @torch.no_grad()
+     def inference_lm(self, data_loader: DataLoader, verbose: bool = True):
+         r"""LM inference step.
+
+         Args:
+             data_loader (DataLoader): Includes tokens, labels, and the gold
+                 mask.
+             verbose (bool): Whether to print a progress bar.
+
+         Returns:
+             preds (torch.Tensor): Predictions from the LM; convert to pseudo
+                 labels via preds.argmax(dim=-1).unsqueeze(1).
+         """
+         if verbose:
+             pbar = tqdm(total=data_loader.dataset._data.num_nodes)
+             pbar.set_description('LM inference stage')
+         self.lm.eval()
+         preds = []
+         for batch in data_loader:
+             inputs = {k: v.to(self.device) for k, v in batch['input'].items()}
+             logits = self.lm(**inputs).logits
+             preds.append(logits)
+             if verbose:
+                 pbar.update(batch['n_id'].size(0))
+         if verbose:
+             pbar.close()
+         preds = torch.cat(preds)
+         return preds
+
+     @torch.no_grad()
+     def inference_gnn(self, data_loader: NeighborLoader, verbose: bool = True):
+         r"""GNN inference step.
+
+         Args:
+             data_loader (NeighborLoader): Includes x and edge_index.
+             verbose (bool): Whether to print a progress bar.
+
+         Returns:
+             preds (torch.Tensor): Predictions from the GNN; convert to pseudo
+                 labels via preds.argmax(dim=-1).unsqueeze(1).
+         """
+         if verbose:
+             pbar = tqdm(total=data_loader.data.num_nodes)
+             pbar.set_description('GNN inference stage')
+         preds = []
+         self.gnn.eval()
+         for batch in data_loader:
+             batch = batch.to(self.device)
+             out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size]
+             preds.append(out)
+             if verbose:
+                 pbar.update(batch.batch_size)
+         if verbose:
+             pbar.close()
+         preds = torch.cat(preds, dim=0)
+         return preds
+
+     def loss(self, logits: torch.Tensor, labels: torch.Tensor,
+              loss_func: torch.nn.functional, is_gold: torch.Tensor,
+              pseudo_labels: torch.Tensor = None, pl_weight: float = 0.5,
+              is_augmented: bool = True):
+         r"""Core function of variational EM inference; it combines the loss
+         on gold (original training) labels with the loss on pseudo labels.
+
+         Reference:
+         <https://github.com/AndyJZhao/GLEM/blob/main/src/models/GLEM/GLEM_utils.py>  # noqa
+
+         Args:
+             logits (torch.Tensor): Predictions from the LM or GNN.
+             labels (torch.Tensor): Combined node labels from ground truth and
+                 pseudo labels (if provided).
+             loss_func (torch.nn.modules.loss): Loss function for
+                 classification.
+             is_gold (torch.Tensor): Boolean mask of ground-truth (gold)
+                 labels used during training; ~is_gold masks pseudo labels.
+             pseudo_labels (torch.Tensor): Predictions from the other model.
+             pl_weight (float): Weight of the pseudo-label loss; alpha in the
+                 E-step and beta in the M-step, respectively.
+             is_augmented (bool): Whether to use EM with pseudo labels or to
+                 train the GNN and LM on gold data only.
+         """
+         if is_augmented and (sum(~is_gold) > 0):
+             mle_loss = deal_nan(loss_func(logits[is_gold], labels[is_gold]))
+             # All labels other than the ground-truth (gold) labels:
+             pseudo_label_loss = deal_nan(
+                 loss_func(logits[~is_gold], pseudo_labels[~is_gold]))
+             loss = pl_weight * pseudo_label_loss + (1 - pl_weight) * mle_loss
+         else:
+             loss = loss_func(logits, labels)
+         return loss
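
For orientation, the following is a minimal sketch of how the new GLEM module might be driven for its GNN (M-step) phase. It is based only on the docstrings above; the toy graph, loader settings, and hyperparameters are illustrative assumptions, and the LM (E-step) phase additionally requires a DataLoader of tokenized node texts, as in examples/llm_plus_gnn/glem.py.

import torch

from torch_geometric.data import Data
from torch_geometric.llm.models.glem import GLEM
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn.models import GraphSAGE

# Toy text-attributed graph; `is_gold` marks nodes whose labels are trusted
# ground truth, the rest would be covered by pseudo labels during EM training:
num_nodes, num_classes = 100, 7
data = Data(
    x=torch.randn(num_nodes, 128),
    edge_index=torch.randint(0, num_nodes, (2, 500)),
    y=torch.randint(0, num_classes, (num_nodes, )),
)
data.is_gold = torch.rand(num_nodes) < 0.3

gnn = GraphSAGE(in_channels=128, hidden_channels=64, num_layers=2,
                out_channels=num_classes)
model = GLEM(lm_to_use='prajjwal1/bert-tiny', gnn_to_use=gnn,
             out_channels=num_classes,
             lm_use_lora=False,  # Skip the optional 'peft' dependency.
             device=torch.device('cpu'))

train_loader = NeighborLoader(data, num_neighbors=[10, 10], batch_size=32,
                              input_nodes=data.is_gold.nonzero(
                                  as_tuple=False).view(-1))
optimizer = torch.optim.Adam(model.gnn.parameters(), lr=0.01)

# M-step warm-up on gold labels only. The E-step (LM phase) additionally
# needs a DataLoader of tokenized node texts; the full EM loop is shown in
# examples/llm_plus_gnn/glem.py.
model.pre_train_gnn(train_loader, optimizer, num_epochs=5, patience=3)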