pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226)
  1. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +8 -3
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +159 -34
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +2 -4
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +322 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +53 -20
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
@@ -0,0 +1,385 @@
1
+ from typing import List, Optional, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from tqdm import tqdm
6
+
7
+ from torch_geometric.loader import DataLoader, NeighborLoader
8
+ from torch_geometric.nn.models import GraphSAGE, basic_gnn
9
+
10
+
11
class GLEM(torch.nn.Module):
    r"""This GNN+LM co-training model is based on GLEM from the `"Learning on
    Large-scale Text-attributed Graphs via Variational Inference"
    <https://arxiv.org/abs/2210.14709>`_ paper.

    Args:
        lm_to_use (str): Name of a HuggingFace model that is loaded with
            :obj:`AutoModelForSequenceClassification` (a text encoder with
            a classification head). (default: :obj:`"prajjwal1/bert-tiny"`)
        gnn_to_use (torch.nn.Module): The GNN trained in the M-step. The
            constructor reads :obj:`gnn_to_use.out_channels`, so pass an
            *instantiated* model. (default: :class:`GraphSAGE`)
        out_channels (int): Number of output classes of the LM; must equal
            :obj:`gnn_to_use.out_channels`. (default: :obj:`47`)
        gnn_loss (callable, optional): Loss function for the GNN. If
            :obj:`None`, a fresh :obj:`CrossEntropyLoss(reduction='mean')`
            is created. (default: :obj:`None`)
        lm_loss (callable, optional): Loss function for the LM. If
            :obj:`None`, a fresh :obj:`CrossEntropyLoss(reduction='mean')`
            is created. (default: :obj:`None`)
        alpha (float): Pseudo-label weight of the E-step (LM optimization).
            (default: :obj:`0.5`)
        beta (float): Pseudo-label weight of the M-step (GNN optimization).
            (default: :obj:`0.5`)
        lm_dtype (torch.dtype): Data type used when loading the LM into
            memory. (default: :obj:`torch.bfloat16`)
        lm_use_lora (bool): Whether to fine-tune the LM with LoRA via
            :obj:`peft`. (default: :obj:`True`)
        lora_target_modules (List[str] or str, optional): Names of the
            target modules to apply the LoRA adapter to, e.g.
            :obj:`['q_proj', 'v_proj']` for an LLM. (default: :obj:`None`)
        device (str or torch.device): Device that batches and pseudo labels
            are moved to during training/inference. The GNN and LM
            themselves are not moved by this class.
            (default: :obj:`torch.device('cpu')`)

    .. note::
        See `examples/llm_plus_gnn/glem.py` for example usage.
    """
    def __init__(
        self,
        lm_to_use: str = 'prajjwal1/bert-tiny',
        gnn_to_use: basic_gnn = GraphSAGE,
        out_channels: int = 47,
        gnn_loss=None,
        lm_loss=None,
        alpha: float = 0.5,
        beta: float = 0.5,
        lm_dtype: torch.dtype = torch.bfloat16,
        lm_use_lora: bool = True,
        lora_target_modules: Optional[Union[List[str], str]] = None,
        device: Union[str, torch.device] = torch.device('cpu'),
    ):
        super().__init__()
        self.device = device
        # Build the default losses per instance instead of sharing a single
        # module object (created once at definition time) across instances.
        # Assignment order (lm_loss, gnn, gnn_loss, lm) matches the original
        # so submodule / state_dict ordering is unchanged.
        self.lm_loss = (lm_loss if lm_loss is not None else
                        nn.CrossEntropyLoss(reduction='mean'))
        self.gnn = gnn_to_use
        self.gnn_loss = (gnn_loss if gnn_loss is not None else
                         nn.CrossEntropyLoss(reduction='mean'))
        self.alpha = alpha
        self.beta = beta

        from transformers import AutoModelForSequenceClassification
        self.lm = AutoModelForSequenceClassification.from_pretrained(
            lm_to_use, num_labels=out_channels, torch_dtype=lm_dtype,
            offload_folder="offload", trust_remote_code=True)
        if lm_use_lora:
            from peft import (
                LoraConfig,
                TaskType,
                get_peft_model,
                prepare_model_for_kbit_training,
            )
            print("Training LM with LORA!")
            self.lm = prepare_model_for_kbit_training(self.lm)
            config = LoraConfig(task_type=TaskType.SEQ_CLS, r=16,
                                lora_alpha=16, lora_dropout=0.05, bias="none",
                                target_modules=lora_target_modules)
            self.lm = get_peft_model(self.lm, config)
            self.lm.print_trainable_parameters()
        self.lm.config.pad_token_id = self.lm.config.eos_token_id
        self.lm_device = self.lm.device

        if self.lm.num_labels != self.gnn.out_channels:
            raise ValueError(
                f"The output channels of the language model "
                f"({self.lm.num_labels}) and the GNN "
                f"({self.gnn.out_channels}) should be the same")

    @staticmethod
    def _fit_with_early_stopping(step, epochs, patience: int) -> None:
        r"""Runs :obj:`step(epoch)` for each epoch in :obj:`epochs`, stopping
        once accuracy has dropped below the best value seen for more than
        :obj:`patience` consecutive epochs.
        """
        best_acc = 0
        bad_epochs = 0
        for epoch in epochs:
            acc, _ = step(epoch)
            if acc < best_acc:
                bad_epochs += 1
                if bad_epochs > patience:
                    print(f'Early stopped by Epoch: {epoch}, '
                          f'Best acc: {best_acc}')
                    break
            else:
                # Patience counts *consecutive* non-improving epochs, so any
                # epoch that does not regress resets the counter.
                bad_epochs = 0
            best_acc = max(best_acc, acc)

    def pre_train_gnn(self, train_loader: NeighborLoader,
                      optimizer: torch.optim.Optimizer, num_epochs: int,
                      patience: int, ext_pseudo_labels: torch.Tensor = None,
                      is_augmented: bool = False, verbose: bool = True):
        r"""Pre-trains the GNN (optional if pseudo labels already exist).

        Epochs are numbered from 0. Training stops early once accuracy has
        been below the best value for more than :obj:`patience` consecutive
        epochs.
        """
        def step(epoch):
            return self.train_gnn(train_loader, optimizer, epoch,
                                  ext_pseudo_labels, is_augmented, verbose)

        self._fit_with_early_stopping(step, range(0, num_epochs), patience)

    def pre_train_lm(self, train_loader: DataLoader,
                     optimizer: torch.optim.Optimizer, num_epochs: int,
                     patience: int, ext_pseudo_labels: torch.Tensor = None,
                     is_augmented: bool = False, verbose: bool = True):
        r"""Pre-trains the language model.

        Epochs are numbered from 1. Training stops early once accuracy has
        been below the best value for more than :obj:`patience` consecutive
        epochs.
        """
        def step(epoch):
            return self.train_lm(train_loader, optimizer, epoch,
                                 ext_pseudo_labels, is_augmented, verbose)

        self._fit_with_early_stopping(step, range(1, num_epochs + 1),
                                      patience)

    def train(
        self,
        em_phase: Union[str, bool] = True,
        train_loader: Optional[Union[DataLoader, NeighborLoader]] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        pseudo_labels: Optional[torch.Tensor] = None,
        epoch: int = 0,
        is_augmented: bool = False,
        verbose: bool = False,
    ):
        r"""GLEM training step, EM steps.

        This method shadows :meth:`torch.nn.Module.train`. When called with
        a boolean (as :meth:`torch.nn.Module.eval` does via
        :obj:`self.train(False)`), it only toggles training mode and returns
        :obj:`self`.

        Args:
            em_phase (str): :obj:`'gnn'` or :obj:`'lm'`, choosing which phase
                to train. A :obj:`bool` sets the module's training mode.
            train_loader (Union[DataLoader, NeighborLoader]): Use
                :obj:`DataLoader` for LM training (tokenized data, labels,
                :obj:`is_gold` mask) and :obj:`NeighborLoader` for GNN
                training (:obj:`x`, :obj:`edge_index`).
            optimizer (torch.optim.Optimizer): Optimizer for training.
            pseudo_labels (torch.Tensor): Predicted labels used as pseudo
                labels.
            epoch (int): Current epoch.
            is_augmented (bool): Whether to use :obj:`pseudo_labels`.
            verbose (bool): Whether to print a training progress bar.

        Returns:
            acc (float): Training accuracy.
            loss (float): Loss value.

        Raises:
            ValueError: If :obj:`em_phase` is neither :obj:`'gnn'` nor
                :obj:`'lm'`.
        """
        if isinstance(em_phase, bool):
            return super().train(em_phase)

        if pseudo_labels is not None:
            pseudo_labels = pseudo_labels.to(self.device)
        if em_phase == 'gnn':
            return self.train_gnn(train_loader, optimizer, epoch,
                                  pseudo_labels, is_augmented, verbose)
        if em_phase == 'lm':
            return self.train_lm(train_loader, optimizer, epoch,
                                 pseudo_labels, is_augmented, verbose)
        raise ValueError(f"Unknown 'em_phase' (got '{em_phase}'); "
                         f"expected 'gnn' or 'lm'")

    def train_lm(self, train_loader: DataLoader,
                 optimizer: torch.optim.Optimizer, epoch: int,
                 pseudo_labels: torch.Tensor = None,
                 is_augmented: bool = False, verbose: bool = True):
        r"""Language model training for one epoch.

        Args:
            train_loader (loader.dataloader.DataLoader): Text token
                dataloader yielding dicts with :obj:`'input'`,
                :obj:`'labels'`, :obj:`'is_gold'` and :obj:`'n_id'`.
            optimizer (torch.optim.Optimizer): Model optimizer.
            epoch (int): Current train epoch.
            pseudo_labels (torch.Tensor): 1-D tensor, predictions from GNN.
            is_augmented (bool): Whether to train with pseudo labels.
            verbose (bool): Whether to print a training progress bar.

        Returns:
            approx_acc (float): Approximate training accuracy.
            loss (float): Mean loss over batches.
        """
        total_loss = total_correct = 0
        num_nodes = train_loader.dataset.indices.size(0)
        self.lm.train()
        if verbose:
            pbar = tqdm(total=num_nodes)
            pbar.set_description(f'Epoch {epoch:02d}')
        for batch in train_loader:
            inputs = {k: v.to(self.device) for k, v in batch['input'].items()}
            out = self.lm(**inputs).logits
            # reshape(-1) instead of squeeze(): squeeze() turns a batch of
            # size one into a 0-d tensor, which the loss cannot consume.
            labels = batch['labels'].to(self.device).reshape(-1)
            is_gold = batch['is_gold'].to(self.device).reshape(-1)
            # Train with pseudo labels only if they were actually provided.
            if is_augmented and pseudo_labels is not None:
                pl_batch = pseudo_labels[batch['n_id']].to(self.device)
            else:
                pl_batch = None
            loss = self.loss(out, labels, self.lm_loss, is_gold, pl_batch,
                             self.alpha, is_augmented)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_correct += int(out.argmax(dim=-1).eq(labels).sum())
            total_loss += float(loss)
            if verbose:
                pbar.update(batch['n_id'].size(0))

        approx_acc = total_correct / num_nodes
        loss = total_loss / len(train_loader)
        if verbose:
            pbar.close()
            print(f'Epoch {epoch:02d} Loss: {loss:.4f} '
                  f'Approx. Train: {approx_acc:.4f}')
        return approx_acc, loss

    def train_gnn(self, train_loader: NeighborLoader,
                  optimizer: torch.optim.Optimizer, epoch: int,
                  pseudo_labels: torch.Tensor = None,
                  is_augmented: bool = False, verbose: bool = True):
        r"""GNN training for one epoch.

        Args:
            train_loader (loader.NeighborLoader): GNN neighbor node loader.
            optimizer (torch.optim.Optimizer): Model optimizer.
            epoch (int): Current train epoch.
            pseudo_labels (torch.Tensor): 1-D tensor, predictions from LM.
            is_augmented (bool): Whether to use pseudo-labeled nodes.
            verbose (bool): Whether to print training progress.

        Returns:
            approx_acc (float): Approximate training accuracy.
            loss (float): Mean loss over batches.
        """
        self.gnn.train()
        num_nodes = train_loader.input_nodes.size(0)
        if verbose:
            pbar = tqdm(total=num_nodes)
            pbar.set_description(f'Epoch {epoch:02d}')
        total_loss = total_correct = 0
        for batch in train_loader:
            batch = batch.to(self.device)
            # Only the first `batch_size` nodes are seed nodes; the rest are
            # sampled neighbors used for message passing.
            out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size]
            labels = batch.y[:batch.batch_size].reshape(-1)
            is_gold_batch = batch.is_gold[:batch.batch_size].reshape(-1)
            # Train with pseudo labels only if they were actually provided.
            if is_augmented and pseudo_labels is not None:
                pl_batch = pseudo_labels[batch.n_id[:batch.batch_size]]
            else:
                pl_batch = None
            loss = self.loss(out, labels, self.gnn_loss, is_gold_batch,
                             pl_batch, self.beta, is_augmented)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_loss += float(loss)
            total_correct += int(out.argmax(dim=-1).eq(labels).sum())
            if verbose:
                pbar.update(batch.batch_size)

        loss = total_loss / len(train_loader)
        approx_acc = total_correct / num_nodes
        if verbose:
            pbar.close()
            print(f'Epoch: {epoch:02d} Loss: {loss:.4f} '
                  f'Approx. Train: {approx_acc:.4f}')
        return approx_acc, loss

    @torch.no_grad()
    def inference(self, em_phase: str, data_loader: Union[NeighborLoader,
                                                          DataLoader],
                  verbose: bool = False):
        r"""GLEM inference step.

        Args:
            em_phase (str): :obj:`'gnn'` or :obj:`'lm'`.
            data_loader (DataLoader or NeighborLoader): :obj:`DataLoader`
                with tokenized data for the LM, :obj:`NeighborLoader` with
                :obj:`x`, :obj:`edge_index` for the GNN.
            verbose (bool): Whether to print inference progress.

        Returns:
            out (torch.Tensor): :obj:`n * m` tensor, where :obj:`m` is the
            number of classes and :obj:`n` the number of nodes, or
            :obj:`None` if :obj:`em_phase` is not recognized.
        """
        out = None
        if em_phase == 'gnn':
            self.gnn.eval()
            out = self.inference_gnn(data_loader, verbose)
        elif em_phase == 'lm':
            self.lm.eval()
            out = self.inference_lm(data_loader, verbose)
        return out

    @torch.no_grad()
    def inference_lm(self, data_loader: DataLoader, verbose: bool = True):
        r"""LM inference step.

        Args:
            data_loader (DataLoader): Includes tokens, labels and gold mask.
            verbose (bool): Whether to print a progress bar.

        Returns:
            preds (torch.Tensor): Logits from the LM; convert to pseudo
            labels via :obj:`preds.argmax(dim=-1).unsqueeze(1)`.
        """
        if verbose:
            pbar = tqdm(total=data_loader.dataset._data.num_nodes)
            pbar.set_description('LM inference stage')
        self.lm.eval()
        preds = []
        for batch in data_loader:
            inputs = {k: v.to(self.device) for k, v in batch['input'].items()}
            logits = self.lm(**inputs).logits
            preds.append(logits)
            if verbose:
                pbar.update(batch['n_id'].size(0))
        if verbose:
            pbar.close()
        preds = torch.cat(preds)
        return preds

    @torch.no_grad()
    def inference_gnn(self, data_loader: NeighborLoader, verbose: bool = True):
        r"""GNN inference step.

        Args:
            data_loader (NeighborLoader): Includes :obj:`x`,
                :obj:`edge_index`.
            verbose (bool): Whether to print a progress bar.

        Returns:
            preds (torch.Tensor): Logits from the GNN; convert to pseudo
            labels via :obj:`preds.argmax(dim=-1).unsqueeze(1)`.
        """
        if verbose:
            pbar = tqdm(total=data_loader.data.num_nodes)
            pbar.set_description('GNN inference stage')
        preds = []
        self.gnn.eval()
        for batch in data_loader:
            batch = batch.to(self.device)
            out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size]
            preds.append(out)
            if verbose:
                pbar.update(batch.batch_size)
        if verbose:
            pbar.close()
        preds = torch.cat(preds, dim=0)
        return preds

    def loss(self, logits: torch.Tensor, labels: torch.Tensor,
             loss_func: torch.nn.functional, is_gold: torch.Tensor,
             pseudo_labels: torch.Tensor = None, pl_weight: float = 0.5,
             is_augmented: bool = True):
        r"""Core function of variational EM inference. It combines the loss
        on gold (original training) labels with the loss on pseudo labels.

        Reference:
        <https://github.com/AndyJZhao/GLEM/blob/main/src/models/GLEM/GLEM_utils.py> # noqa

        Args:
            logits (torch.Tensor): Predictions from the LM or GNN.
            labels (torch.Tensor): Node labels (ground truth for gold nodes).
            loss_func (callable): Classification loss function.
            is_gold (torch.Tensor): Boolean mask of ground-truth labeled
                nodes; :obj:`~is_gold` selects pseudo-labeled nodes.
            pseudo_labels (torch.Tensor, optional): Predictions from the
                other model. If :obj:`None`, only :obj:`labels` is used.
            pl_weight (float): Weight of the pseudo-label loss (alpha in the
                E-step, beta in the M-step).
            is_augmented (bool): Use the EM objective, or just train on
                :obj:`labels`.
        """
        def deal_nan(x):
            # A mean loss over an empty selection is NaN; treat it as zero.
            # Return a (graph-detached) tensor rather than a Python int so
            # the result type is always a tensor.
            return torch.zeros_like(x) if torch.isnan(x) else x

        use_pseudo = (is_augmented and pseudo_labels is not None
                      and bool((~is_gold).any()))
        if use_pseudo:
            mle_loss = deal_nan(loss_func(logits[is_gold], labels[is_gold]))
            # All other labels beside the ground truth (gold labels).
            pseudo_label_loss = deal_nan(
                loss_func(logits[~is_gold], pseudo_labels[~is_gold]))
            loss = pl_weight * pseudo_label_loss + (1 - pl_weight) * mle_loss
        else:
            loss = loss_func(logits, labels)
        return loss
@@ -19,7 +19,6 @@ class GaussianFilter(torch.nn.Module):
19
19
 
20
20
  def reset_parameters(self):
21
21
  r"""Resets all learnable parameters of the module."""
22
- pass
23
22
 
24
23
  def forward(self, dist: Tensor) -> Tensor:
25
24
  dist = dist.view(-1, 1) - self.offset.view(1, -1)
@@ -79,12 +79,21 @@ class GraphUNet(torch.nn.Module):
79
79
  for conv in self.up_convs:
80
80
  conv.reset_parameters()
81
81
 
82
- def forward(self, x: Tensor, edge_index: Tensor,
83
- batch: OptTensor = None) -> Tensor:
82
+ def forward(
83
+ self,
84
+ x: Tensor,
85
+ edge_index: Tensor,
86
+ batch: OptTensor = None,
87
+ edge_weight: Tensor = None,
88
+ ) -> Tensor:
84
89
  """""" # noqa: D419
85
90
  if batch is None:
86
91
  batch = edge_index.new_zeros(x.size(0))
87
- edge_weight = x.new_ones(edge_index.size(1))
92
+
93
+ if edge_weight is None:
94
+ edge_weight = x.new_ones(edge_index.size(1))
95
+ assert edge_weight.dim() == 1
96
+ assert edge_weight.size(0) == edge_index.size(1)
88
97
 
89
98
  x = self.down_convs[0](x, edge_index, edge_weight)
90
99
  x = self.act(x)
@@ -1,4 +1,4 @@
1
- from typing import List, Optional
1
+ from typing import Dict, List, Optional
2
2
 
3
3
  import torch
4
4
  from torch import Tensor
@@ -41,8 +41,12 @@ class JumpingKnowledge(torch.nn.Module):
41
41
  num_layers (int, optional): The number of layers to aggregate. Needs to
42
42
  be only set for LSTM-style aggregation. (default: :obj:`None`)
43
43
  """
44
- def __init__(self, mode: str, channels: Optional[int] = None,
45
- num_layers: Optional[int] = None):
44
+ def __init__(
45
+ self,
46
+ mode: str,
47
+ channels: Optional[int] = None,
48
+ num_layers: Optional[int] = None,
49
+ ) -> None:
46
50
  super().__init__()
47
51
  self.mode = mode.lower()
48
52
  assert self.mode in ['cat', 'max', 'lstm']
@@ -63,7 +67,7 @@ class JumpingKnowledge(torch.nn.Module):
63
67
 
64
68
  self.reset_parameters()
65
69
 
66
- def reset_parameters(self):
70
+ def reset_parameters(self) -> None:
67
71
  r"""Resets all learnable parameters of the module."""
68
72
  if self.lstm is not None:
69
73
  self.lstm.reset_parameters()
@@ -94,3 +98,58 @@ class JumpingKnowledge(torch.nn.Module):
94
98
  return (f'{self.__class__.__name__}({self.mode}, '
95
99
  f'channels={self.channels}, layers={self.num_layers})')
96
100
  return f'{self.__class__.__name__}({self.mode})'
101
+
102
+
103
class HeteroJumpingKnowledge(torch.nn.Module):
    r"""A heterogeneous version of the :class:`JumpingKnowledge` module that
    keeps one independent :class:`JumpingKnowledge` aggregator per type.

    Args:
        types (List[str]): The keys of the input dictionary.
        mode (str): The aggregation scheme to use
            (:obj:`"cat"`, :obj:`"max"` or :obj:`"lstm"`).
        channels (int, optional): The number of channels per representation.
            Needs to be only set for LSTM-style aggregation.
            (default: :obj:`None`)
        num_layers (int, optional): The number of layers to aggregate. Needs to
            be only set for LSTM-style aggregation. (default: :obj:`None`)
    """
    def __init__(
        self,
        types: List[str],
        mode: str,
        channels: Optional[int] = None,
        num_layers: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.mode = mode.lower()

        # One aggregator per type, created in the order `types` is given.
        per_type = {}
        for type_name in types:
            per_type[type_name] = JumpingKnowledge(mode, channels, num_layers)
        self.jk_dict = torch.nn.ModuleDict(per_type)

    def reset_parameters(self) -> None:
        r"""Resets all learnable parameters of the module."""
        for type_name in self.jk_dict:
            self.jk_dict[type_name].reset_parameters()

    def forward(self, xs_dict: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
        r"""Forward pass.

        Args:
            xs_dict (Dict[str, List[torch.Tensor]]): A dictionary holding a
                list of layer-wise representation for each type.
        """
        aggregated: Dict[str, Tensor] = {}
        for type_name, aggregator in self.jk_dict.items():
            aggregated[type_name] = aggregator(xs_dict[type_name])
        return aggregated

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        num_types = len(self.jk_dict)
        if self.mode != 'lstm':
            return f'{cls_name}(num_types={num_types}, mode={self.mode})'
        # All aggregators share the same settings; report the first one's.
        first = next(iter(self.jk_dict.values()))
        return (f'{cls_name}(num_types={num_types}, mode={self.mode}, '
                f'channels={first.channels}, layers={first.num_layers})')
@@ -275,7 +275,7 @@ class BPRLoss(_Loss):
275
275
  \sum_{j \not\in \mathcal{N}_u} \ln \sigma(\hat{y}_{ui} - \hat{y}_{uj})
276
276
  + \lambda \vert\vert \textbf{x}^{(0)} \vert\vert^2
277
277
 
278
- where :math:`lambda` controls the :math:`L_2` regularization strength.
278
+ where :math:`\lambda` controls the :math:`L_2` regularization strength.
279
279
  We compute the mean BPR loss for simplicity.
280
280
 
281
281
  Args:
@@ -5,9 +5,9 @@ from torch import Tensor
5
5
  from torch.nn import Embedding
6
6
  from torch.utils.data import DataLoader
7
7
 
8
+ from torch_geometric.index import index2ptr
8
9
  from torch_geometric.typing import EdgeType, NodeType, OptTensor
9
10
  from torch_geometric.utils import sort_edge_index
10
- from torch_geometric.utils.sparse import index2ptr
11
11
 
12
12
  EPS = 1e-15
13
13
 
@@ -103,7 +103,7 @@ class MetaPath2Vec(torch.nn.Module):
103
103
  self.num_negative_samples = num_negative_samples
104
104
  self.num_nodes_dict = num_nodes_dict
105
105
 
106
- types = set([x[0] for x in metapath]) | set([x[-1] for x in metapath])
106
+ types = {x[0] for x in metapath} | {x[-1] for x in metapath}
107
107
  types = sorted(list(types))
108
108
 
109
109
  count = 0
@@ -227,14 +227,13 @@ class MetaPath2Vec(torch.nn.Module):
227
227
  return pos_loss + neg_loss
228
228
 
229
229
  def test(self, train_z: Tensor, train_y: Tensor, test_z: Tensor,
230
- test_y: Tensor, solver: str = "lbfgs", multi_class: str = "auto",
231
- *args, **kwargs) -> float:
230
+ test_y: Tensor, solver: str = "lbfgs", *args, **kwargs) -> float:
232
231
  r"""Evaluates latent space quality via a logistic regression downstream
233
232
  task.
234
233
  """
235
234
  from sklearn.linear_model import LogisticRegression
236
235
 
237
- clf = LogisticRegression(solver=solver, multi_class=multi_class, *args,
236
+ clf = LogisticRegression(solver=solver, *args,
238
237
  **kwargs).fit(train_z.detach().cpu().numpy(),
239
238
  train_y.detach().cpu().numpy())
240
239
  return clf.score(test_z.detach().cpu().numpy(),
@@ -256,6 +255,7 @@ def sample(rowptr: Tensor, col: Tensor, rowcount: Tensor, subset: Tensor,
256
255
  rand = torch.rand((subset.size(0), num_neighbors), device=subset.device)
257
256
  rand *= count.to(rand.dtype).view(-1, 1)
258
257
  rand = rand.to(torch.long) + rowptr[subset].view(-1, 1)
258
+ rand = rand.clamp(max=col.numel() - 1) # If last node is isolated.
259
259
 
260
260
  col = col[rand] if col.numel() > 0 else rand
261
261
  col[mask | (count == 0)] = dummy_idx