pyg-nightly 2.6.0.dev20240909__py3-none-any.whl → 2.6.0.dev20240910__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in the public registry.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: pyg-nightly
- Version: 2.6.0.dev20240909
+ Version: 2.6.0.dev20240910
  Summary: Graph Neural Network Library for PyTorch
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
- torch_geometric/__init__.py,sha256=8YENZOJAbIgDzXkgwPPNDmQik7Kny9kIWjNlWiAnJYg,1904
+ torch_geometric/__init__.py,sha256=vRKXyHIGqBHJUKx9tLap5c3uB1Mbb6ZOlvVgapW_D6Q,1904
  torch_geometric/_compile.py,sha256=0HAdz6MGmyrgi4g6P-PorTg8dPIKx3Jo4zVJavrlfX0,1139
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -416,7 +416,7 @@ torch_geometric/nn/kge/distmult.py,sha256=dGQ0bVzjreZgFN1lXE23_IIidsiOq7ehPrMb-N
  torch_geometric/nn/kge/loader.py,sha256=5Uc1j3OUMQnBYSHDqL7pLCty1siFLzoPkztigYO2zP8,771
  torch_geometric/nn/kge/rotate.py,sha256=XLuO1AbyTt5cJxr97ZzoyAyIEsHKesgW5TvDmnGJAao,3208
  torch_geometric/nn/kge/transe.py,sha256=jlejq5BLMm-sb1wWcLDp7pZqCdelWBgjDIC8ctbjSdU,3088
- torch_geometric/nn/models/__init__.py,sha256=_2KGXTo3eAgwcwAo0xIpw_I41n_cuTimJTZQgG0zKEc,1963
+ torch_geometric/nn/models/__init__.py,sha256=RpYFFqaYWq1BVMF3Fs-EQo-QZDdLQjIHPdkl3d2MOW4,2017
  torch_geometric/nn/models/attentive_fp.py,sha256=tkgvw28wg9-JqHIfBllfCwTHrZIUiv85yZJcDqjz3z0,6634
  torch_geometric/nn/models/autoencoder.py,sha256=nGje-zty78Y3hxOJ9o0_6QziJjOvBlknk6z0_fDQwQU,10770
  torch_geometric/nn/models/basic_gnn.py,sha256=PGa0RUMyvrNy_5yRI2jX_zwPsmZXwOQWfsWvxOiHsSk,31225
@@ -426,6 +426,7 @@ torch_geometric/nn/models/deep_graph_infomax.py,sha256=u6j-5-iHBASDCZ776dyfCI1N8
  torch_geometric/nn/models/deepgcn.py,sha256=tIgT03cj8MghYlxEozpoGvGG_CwpJrGDxv1Z0CVIUts,4339
  torch_geometric/nn/models/dimenet.py,sha256=Kc5p-rB5q-0e8lY22l-OdQTscTxJh2lTEpeRFMdL4RY,36186
  torch_geometric/nn/models/dimenet_utils.py,sha256=xP_nbzkSSL25GC3rrZ9KP8x9QZ59S-CZuHzCmQ-K0fI,5062
+ torch_geometric/nn/models/g_retriever.py,sha256=uH_aYrFbFNHaAeKQn_LtUgP5ajutLYYD8N9UvSKcpfk,7271
  torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
  torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
  torch_geometric/nn/models/graph_unet.py,sha256=WFb7d_DBByMGyXh3AdK2CKNmvMmSKsSUt8l8UnSOovs,5395
@@ -448,7 +449,7 @@ torch_geometric/nn/models/signed_gcn.py,sha256=J40CnedFIqtKI1LhW1ITSEFRbA_XiJZL6
  torch_geometric/nn/models/tgn.py,sha256=kEGdfLJybkbMT4UMoAh2nCzfX3_nDjfm1cicuPHEwAM,11878
  torch_geometric/nn/models/visnet.py,sha256=97OFMCsPDEI5BCSi7RhoRcU2CNRp7zck2tEzrltFZj4,43192
  torch_geometric/nn/nlp/__init__.py,sha256=JJESTA7w_K8v60XbCd25IqmrKKHLz5OiNexMHYGV2mE,138
- torch_geometric/nn/nlp/llm.py,sha256=blCLWkm76bKMxGgOj7dxMXcyn9ecngX1LKDkP-MRSW4,10824
+ torch_geometric/nn/nlp/llm.py,sha256=KwSXgI55FuHLR_9vhgekDXMaRUodPQceHPD7OCp2KN4,11639
  torch_geometric/nn/nlp/sentence_transformer.py,sha256=DzbQO8wgR34BkKpXfMqQu61hMrK94W2MBa3bZ4fDmVs,3114
  torch_geometric/nn/norm/__init__.py,sha256=u2qIDrkbeuObGVXSAIftAlvSd6ouGTtxznCfD-59UiA,669
  torch_geometric/nn/norm/batch_norm.py,sha256=sJKrinHGwA-noIgteg1RD2W06rd0zskD-rXuY-36glY,8283
@@ -460,13 +461,14 @@ torch_geometric/nn/norm/layer_norm.py,sha256=pWo5q8rLNSaU2fECpP7L8T_airtaukjOztL
  torch_geometric/nn/norm/mean_subtraction_norm.py,sha256=KVHOp413mw7obwAN09Le6XdgobtCXpi4UKpjpG1M550,1322
  torch_geometric/nn/norm/msg_norm.py,sha256=zaQtqhs55LU-e6KPC4ylaSdge4KvEoseqOt7pmAzi2s,1662
  torch_geometric/nn/norm/pair_norm.py,sha256=IfHMiVYw_xsy035NakbPGdQVaVC-Ue3Oxwo651Vc47I,2824
- torch_geometric/nn/pool/__init__.py,sha256=pJsD4qumvCu_oZUtC-orZCHp9nObx-VMWHFlJckFrHc,14129
+ torch_geometric/nn/pool/__init__.py,sha256=2Bi-_xlsGIUUKDeOO7BhaTqCc5n6_ixbu_MO9pglMts,14192
  torch_geometric/nn/pool/approx_knn.py,sha256=n7C8Cbar6o5tJcuAbzhM5hqMK26hW8dm5DopuocidO0,3967
  torch_geometric/nn/pool/asap.py,sha256=p8fwpMOeCUyJrdvMmLoTMzr0tI9YCTnefMx8ylIv5xE,6683
  torch_geometric/nn/pool/avg_pool.py,sha256=pwiQh14BCVsT-iULqVAFW-Dxt7DjFOu8CQX_Hu34vZc,3966
+ torch_geometric/nn/pool/cluster_pool.py,sha256=et2YaFu1kf-o6Eg9XpqHGp_Cqv68DndWbE88VJHOSPQ,5227
  torch_geometric/nn/pool/consecutive.py,sha256=7dMiMd5IybNeml1RqZq436FI6sod5ZUxTuDWJjr5syo,273
  torch_geometric/nn/pool/decimation.py,sha256=AjbU2h_Gl_EQcfkhF977EnrLJ2kait_e4HyCNKRyxPw,1601
- torch_geometric/nn/pool/edge_pool.py,sha256=ZdDv0t1CYwdgg56V8oFTvYozHHzerMoltEtVsY-9Wv8,8581
+ torch_geometric/nn/pool/edge_pool.py,sha256=cXgcN5xF8z5NeycYMX9m1zoAk1jtSdyK42YiNNHTeow,8571
  torch_geometric/nn/pool/glob.py,sha256=RJrq1sgAe8oV15WSGtXgB6yXWj2irSJIWAdQLb0byN4,3492
  torch_geometric/nn/pool/graclus.py,sha256=dL9tasXNM-x2NOMRJn8k6z4CeW46nRzoa49IzG58wow,1349
  torch_geometric/nn/pool/knn.py,sha256=fNZV0q2A4lzhZQyePRLHSrtuWjbxQxvv3V7oeNzBLVk,11343
@@ -615,6 +617,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
  torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
  torch_geometric/visualization/graph.py,sha256=SvbdVx5Zmuy_WSSA4-WWCkqAcCSHVe84mjMfsEWbZCs,4813
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
- pyg_nightly-2.6.0.dev20240909.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
- pyg_nightly-2.6.0.dev20240909.dist-info/METADATA,sha256=TyeRwIRUgTAzeT0sDoNXKBQ80n5fjfhulW918W_iAWs,63068
- pyg_nightly-2.6.0.dev20240909.dist-info/RECORD,,
+ pyg_nightly-2.6.0.dev20240910.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
+ pyg_nightly-2.6.0.dev20240910.dist-info/METADATA,sha256=4d3E2ca0L5gmd30HK26OTddin11CJkS3iteHVYcnxfI,63068
+ pyg_nightly-2.6.0.dev20240910.dist-info/RECORD,,
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
  
- __version__ = '2.6.0.dev20240909'
+ __version__ = '2.6.0.dev20240910'
  
  __all__ = [
      'Index',
@@ -28,6 +28,7 @@ from .gnnff import GNNFF
  from .pmlp import PMLP
  from .neural_fingerprint import NeuralFingerprint
  from .visnet import ViSNet
+ from .g_retriever import GRetriever
  
  # Deprecated:
  from torch_geometric.explain.algorithm.captum import (to_captum_input,
@@ -75,4 +76,5 @@ __all__ = classes = [
      'PMLP',
      'NeuralFingerprint',
      'ViSNet',
+     'GRetriever',
  ]
@@ -0,0 +1,205 @@
+ from typing import List, Optional
+ 
+ import torch
+ from torch import Tensor
+ 
+ from torch_geometric.nn.models import GAT
+ from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS
+ from torch_geometric.utils import scatter
+ 
+ 
+ class GRetriever(torch.nn.Module):
+     r"""The G-Retriever model from the `"G-Retriever: Retrieval-Augmented
+     Generation for Textual Graph Understanding and Question Answering"
+     <https://arxiv.org/abs/2402.07630>`_ paper.
+ 
+     Args:
+         llm (LLM): The LLM to use.
+         gnn (torch.nn.Module): The GNN to use.
+         use_lora (bool, optional): If set to :obj:`True`, will use LoRA from
+             :obj:`peft` for training the LLM, see
+             `here <https://huggingface.co/docs/peft/en/index>`_ for details.
+             (default: :obj:`False`)
+         mlp_out_channels (int, optional): The size of each graph embedding
+             after projection. (default: :obj:`4096`)
+ 
+     .. warning::
+         This module has been tested with the following HuggingFace models
+ 
+         * :obj:`llm_to_use="meta-llama/Llama-2-7b-chat-hf"`
+         * :obj:`llm_to_use="google/gemma-7b"`
+ 
+         and may not work with other models. See other models at `HuggingFace
+         Models <https://huggingface.co/models>`_ and let us know if you
+         encounter any issues.
+ 
+     .. note::
+         For an example of using :class:`GRetriever`, see
+         `examples/llm/g_retriever.py <https://github.com/pyg-team/
+         pytorch_geometric/blob/master/examples/llm/g_retriever.py>`_.
+     """
+     def __init__(
+         self,
+         llm: LLM,
+         gnn: torch.nn.Module,
+         use_lora: bool = False,
+         gnn_to_use=GAT,
+         mlp_out_channels: int = 4096,
+     ) -> None:
+         super().__init__()
+ 
+         self.llm = llm
+         self.gnn = gnn.to(self.llm.device)
+ 
+         self.word_embedding = self.llm.word_embedding
+         self.llm_generator = self.llm.llm
+         if use_lora:
+             from peft import (
+                 LoraConfig,
+                 get_peft_model,
+                 prepare_model_for_kbit_training,
+             )
+             self.llm_generator = prepare_model_for_kbit_training(
+                 self.llm_generator)
+             lora_r: int = 8
+             lora_alpha: int = 16
+             lora_dropout: float = 0.05
+             lora_target_modules = ['q_proj', 'v_proj']
+             config = LoraConfig(
+                 r=lora_r,
+                 lora_alpha=lora_alpha,
+                 target_modules=lora_target_modules,
+                 lora_dropout=lora_dropout,
+                 bias='none',
+                 task_type='CAUSAL_LM',
+             )
+             self.llm_generator = get_peft_model(self.llm_generator, config)
+ 
+         mlp_hidden_channels = self.gnn.out_channels
+         self.projector = torch.nn.Sequential(
+             torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
+             torch.nn.Sigmoid(),
+             torch.nn.Linear(mlp_hidden_channels, mlp_out_channels),
+         ).to(self.llm.device)
+ 
+     def encode(
+         self,
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+         edge_attr: Optional[Tensor],
+     ) -> Tensor:
+         x = x.to(self.llm.device)
+         edge_index = edge_index.to(self.llm.device)
+         if edge_attr is not None:
+             edge_attr = edge_attr.to(self.llm.device)
+         batch = batch.to(self.llm.device)
+ 
+         out = self.gnn(x, edge_index, edge_attr=edge_attr)
+         return scatter(out, batch, dim=0, reduce='mean')
+ 
+     def forward(
+         self,
+         question: List[str],
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+         label: List[str],
+         edge_attr: Optional[Tensor] = None,
+         additional_text_context: Optional[List[str]] = None,
+     ):
+         r"""The forward pass.
+ 
+         Args:
+             question (List[str]): The questions/prompts.
+             x (torch.Tensor): The input node features.
+             edge_index (torch.Tensor): The edge indices.
+             batch (torch.Tensor): The batch vector
+                 :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
+                 each element to a specific example.
+             label (List[str]): The answers/labels.
+             edge_attr (torch.Tensor, optional): The edge features (if supported
+                 by the GNN). (default: :obj:`None`)
+             additional_text_context (List[str], optional): Additional context
+                 to give to the LLM, such as textified knowledge graphs.
+                 (default: :obj:`None`)
+         """
+         x = self.encode(x, edge_index, batch, edge_attr)
+         x = self.projector(x)
+         xs = x.split(x.size(0), dim=0)
+ 
+         (
+             inputs_embeds,
+             attention_mask,
+             label_input_ids,
+         ) = self.llm._get_embeds(question, additional_text_context, xs, label)
+ 
+         with self.llm.autocast_context:
+             outputs = self.llm_generator(
+                 inputs_embeds=inputs_embeds,
+                 attention_mask=attention_mask,
+                 return_dict=True,
+                 labels=label_input_ids,
+             )
+ 
+         return outputs.loss
+ 
+     @torch.no_grad()
+     def inference(
+         self,
+         question: List[str],
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+         edge_attr: Optional[Tensor] = None,
+         additional_text_context: Optional[List[str]] = None,
+         max_out_tokens: Optional[int] = MAX_NEW_TOKENS,
+     ):
+         r"""The inference pass.
+ 
+         Args:
+             question (List[str]): The questions/prompts.
+             x (torch.Tensor): The input node features.
+             edge_index (torch.Tensor): The edge indices.
+             batch (torch.Tensor): The batch vector
+                 :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
+                 each element to a specific example.
+             edge_attr (torch.Tensor, optional): The edge features (if supported
+                 by the GNN). (default: :obj:`None`)
+             additional_text_context (List[str], optional): Additional context
+                 to give to the LLM, such as textified knowledge graphs.
+                 (default: :obj:`None`)
+             max_out_tokens (int, optional): How many tokens for the LLM to
+                 generate. (default: :obj:`32`)
+         """
+         x = self.encode(x, edge_index, batch, edge_attr)
+         x = self.projector(x)
+         xs = x.split(x.size(0), dim=0)
+ 
+         inputs_embeds, attention_mask, _ = self.llm._get_embeds(
+             question, additional_text_context, xs)
+ 
+         bos_token = self.llm.tokenizer(
+             BOS,
+             add_special_tokens=False,
+         ).input_ids[0]
+ 
+         with self.llm.autocast_context:
+             outputs = self.llm_generator.generate(
+                 inputs_embeds=inputs_embeds,
+                 max_new_tokens=max_out_tokens,
+                 attention_mask=attention_mask,
+                 bos_token_id=bos_token,
+                 use_cache=True,  # Important to set!
+             )
+ 
+         return self.llm.tokenizer.batch_decode(
+             outputs,
+             skip_special_tokens=True,
+         )
+ 
+     def __repr__(self) -> str:
+         return (f'{self.__class__.__name__}(\n'
+                 f'  llm={self.llm},\n'
+                 f'  gnn={self.gnn},\n'
+                 f')')
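For orientation, here is a minimal training-step sketch of the GRetriever class added above. The concrete LLM and GAT configurations are illustrative assumptions, not part of this diff; the two models listed in the class warning are the tested ones.

    import torch
    from torch_geometric.nn.models import GAT, GRetriever
    from torch_geometric.nn.nlp import LLM

    # Hypothetical setup; the GNN's out_channels feeds GRetriever's projector.
    llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7)
    gnn = GAT(in_channels=1024, hidden_channels=1024, out_channels=1024,
              num_layers=4)
    model = GRetriever(llm=llm, gnn=gnn, mlp_out_channels=4096)

    # forward() encodes the graph, projects the mean-pooled embedding into
    # the LLM's token space, and returns the LM loss on the label tokens:
    loss = model(
        question=['What does this graph describe?'],
        x=torch.randn(10, 1024),                   # node features
        edge_index=torch.randint(0, 10, (2, 40)),  # random toy edges
        batch=torch.zeros(10, dtype=torch.long),   # a single graph
        label=['An example answer.'],
    )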
@@ -1,10 +1,14 @@
- import warnings
  from contextlib import nullcontext
  from typing import Any, Dict, List, Optional
  
  import torch
  from torch import Tensor
  
+ try:
+     from transformers.tokenization_utils_base import BatchEncoding
+ except ImportError:
+     BatchEncoding = Dict
+ 
  BOS = '<s>[INST]'
  EOS_USER = '[/INST]'
  EOS = '[/s]'
@@ -61,23 +65,16 @@ class LLM(torch.nn.Module):
      ) -> None:
          super().__init__()
  
-         from transformers import AutoModelForCausalLM, AutoTokenizer
+         self.model_name = model_name
  
-         if model_name == 'llama2-7b':
-             pretty_model_name = 'LLAMA2'
-             model_name = 'meta-llama/Llama-2-7b-chat-hf'
-         elif model_name == 'gemma':
-             pretty_model_name = 'GEMMA'
-             model_name = 'google/gemma-7b'
-         else:
-             pretty_model_name = model_name
+         from transformers import AutoModelForCausalLM, AutoTokenizer
  
          # A rough heuristic on GPU memory requirements, e.g., we found that
          # LLAMA2 (7B parameters) fits on an 85GB GPU.
          required_memory = 85 * num_params / 7
          kwargs = get_llm_kwargs(required_memory, dtype)
  
-         print(f"Setting up '{pretty_model_name}' with configuration: {kwargs}")
+         print(f"Setting up '{model_name}' with configuration: {kwargs}")
          self.tokenizer = AutoTokenizer.from_pretrained(
              model_name,
              use_fast=False,
@@ -88,17 +85,17 @@ class LLM(torch.nn.Module):
          self.word_embedding = self.llm.model.get_input_embeddings()
  
          if 'max_memory' not in kwargs:  # Pure CPU:
-             self.llm_device = torch.device('cpu')
+             self.device = torch.device('cpu')
              self.autocast_context = nullcontext()
          else:
-             self.llm_device = self.llm.device
+             self.device = self.llm.device
              self.autocast_context = torch.cuda.amp.autocast(dtype=dtype)
  
      def _encode_inputs(
          self,
          question: List[str],
          context: Optional[List[str]] = None,
-     ) -> None:
+     ) -> tuple:
          batch_size = len(question)
          questions = self.tokenizer(question, add_special_tokens=False)
          if context is not None:
@@ -109,14 +106,144 @@ class LLM(torch.nn.Module):
              BOS,
              add_special_tokens=False,
              return_tensors='pt',
-         ).input_ids[0].to(self.llm_device)
+         ).input_ids[0].to(self.device)
          bos_embeds = self.word_embedding(bos_token)
          pad_token = torch.tensor(self.tokenizer.pad_token_id,
-                                  device=self.llm_device)
+                                  device=self.device)
          pad_embeds = self.word_embedding(pad_token).unsqueeze(0)
          return (batch_size, questions, context, eos_user_tokens, bos_embeds,
                  pad_embeds)
  
+     def _label_input_ids(
+         self,
+         i: int,
+         label: BatchEncoding,
+         eos_tokens: BatchEncoding,
+     ) -> List[int]:
+         label_input_ids = label.input_ids[i][:MAX_NEW_TOKENS]
+         label_input_ids = label_input_ids + eos_tokens.input_ids
+         return label_input_ids
+ 
+     def _input_ids(
+         self,
+         i: int,
+         context: BatchEncoding,
+         question: BatchEncoding,
+         eos_user_tokens: BatchEncoding,
+     ) -> List[int]:
+         input_ids: List[int] = []
+         if context is not None:
+             input_ids += context.input_ids[i][:MAX_TXT_LEN]
+         input_ids += question.input_ids[i]
+         input_ids += eos_user_tokens.input_ids
+         return input_ids
+ 
+     def _inputs_embeds(
+         self,
+         i: int,
+         input_ids: List[int],
+         bos_embeds: Tensor,
+         embedding: Optional[List[Tensor]] = None,
+     ) -> Tensor:
+         inputs_embeds = self.word_embedding(
+             torch.tensor(input_ids, device=self.device))
+ 
+         to_cat = [bos_embeds]
+         if embedding is not None and embedding[i] is not None:
+             to_cat.append(embedding[i])
+         to_cat.append(inputs_embeds)
+         return torch.cat(to_cat, dim=0).to(self.device)
+ 
+     def _append_embeds(
+         self,
+         inputs_embeds: Tensor,
+         batch_inputs_embeds: List[Tensor],
+         batch_attention_mask: List[List[int]],
+         label_input_ids: Optional[List[int]] = None,
+         batch_label_input_ids: Optional[List[List[int]]] = None,
+     ) -> tuple:
+         batch_inputs_embeds.append(inputs_embeds)
+         batch_attention_mask.append([1] * inputs_embeds.size(0))
+         if label_input_ids is not None:
+             pad = inputs_embeds.size(0) - len(label_input_ids)
+             label_input_ids = [IGNORE_INDEX] * pad + label_input_ids
+             batch_label_input_ids.append(label_input_ids)
+         return batch_inputs_embeds, batch_attention_mask, batch_label_input_ids
+ 
+     def _pad_embeds(
+         self,
+         pad_embeds: Tensor,
+         batch_inputs_embeds: List[Tensor],
+         batch_attention_mask: List[List[int]],
+         batch_label_input_ids: Optional[List[List[int]]] = None,
+     ) -> tuple:
+         max_length = max([x.size(0) for x in batch_inputs_embeds])
+         batch_size = len(batch_inputs_embeds)
+         for i in range(batch_size):
+             pad = max_length - batch_inputs_embeds[i].size(0)
+             batch_inputs_embeds[i] = torch.cat([
+                 pad_embeds.repeat(pad, 1),
+                 batch_inputs_embeds[i],
+             ])
+             batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
+             if batch_label_input_ids is not None:
+                 tmp = [IGNORE_INDEX] * pad + batch_label_input_ids[i]
+                 batch_label_input_ids[i] = tmp
+         inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
+         attention_mask = torch.tensor(batch_attention_mask, device=self.device)
+         label_input_ids = None
+         if batch_label_input_ids is not None:
+             label_input_ids = torch.tensor(batch_label_input_ids,
+                                            device=self.device)
+         return inputs_embeds, attention_mask, label_input_ids
+ 
+     def _get_embeds(
+         self,
+         question: List[str],
+         context: Optional[List[str]] = None,
+         embedding: Optional[List[Tensor]] = None,
+         answer: Optional[List[str]] = None,
+     ) -> tuple:
+         (batch_size, question, context, eos_user_tokens, bos_embeds,
+          pad_embeds) = self._encode_inputs(question, context)
+ 
+         batch_label_input_ids = None
+         if answer is not None:
+             label = self.tokenizer(answer, add_special_tokens=False)
+             eos_tokens = self.tokenizer(EOS, add_special_tokens=False)
+             batch_label_input_ids = []
+ 
+         batch_inputs_embeds = []
+         batch_attention_mask = []
+         for i in range(batch_size):
+             input_ids = self._input_ids(i, context, question, eos_user_tokens)
+             if answer is not None:
+                 label_input_ids = self._label_input_ids(i, label, eos_tokens)
+                 input_ids += label_input_ids
+             else:
+                 label_input_ids = None
+ 
+             inputs_embeds = self._inputs_embeds(i, input_ids, bos_embeds,
+                                                 embedding)
+ 
+             (
+                 batch_inputs_embeds,
+                 batch_attention_mask,
+                 batch_label_input_ids,
+             ) = self._append_embeds(
+                 inputs_embeds,
+                 batch_inputs_embeds,
+                 batch_attention_mask,
+                 label_input_ids,
+                 batch_label_input_ids,
+             )
+ 
+         inputs_embeds, attention_mask, label_input_ids = self._pad_embeds(
+             pad_embeds, batch_inputs_embeds, batch_attention_mask,
+             batch_label_input_ids)
+ 
+         return inputs_embeds, attention_mask, label_input_ids
+ 
      def forward(
          self,
          question: List[str],
@@ -133,65 +260,11 @@ class LLM(torch.nn.Module):
                  LLM, such as textified knowledge graphs. (default: :obj:`None`)
              embedding (list[torch.Tensor], optional): RAG embedding
                  tensors, *i.e.* the embedded form of :obj:`context`. Either
-                 :obj:`context` or :obj:`rag_embeddings` should be used, not
+                 :obj:`context` or :obj:`embedding` should be used, not
                  both. (default: :obj:`None`)
          """
-         if context is not None and embedding is not None:
-             warnings.warn("Using both 'context' and 'embedding' is a waste of "
-                           "compute and memory")
- 
-         (batch_size, question, context, eos_user_tokens, bos_embeds,
-          pad_embeds) = self._encode_inputs(question, context)
- 
-         label = self.tokenizer(answer, add_special_tokens=False)
-         eos_tokens = self.tokenizer(EOS, add_special_tokens=False)
- 
-         batch_inputs_embeds = []
-         batch_attention_mask = []
-         batch_label_input_ids = []
-         for i in range(batch_size):
-             label_input_ids = label.input_ids[i][:MAX_NEW_TOKENS]
-             label_input_ids += eos_tokens.input_ids  # Add EOS token.
- 
-             input_ids: List[int] = []
-             if context is not None:
-                 input_ids += context.input_ids[i][:MAX_TXT_LEN]
-             input_ids += question.input_ids[i]
-             input_ids += eos_user_tokens.input_ids
-             input_ids += label_input_ids
- 
-             inputs_embeds = self.word_embedding(
-                 torch.tensor(input_ids, device=self.llm_device))
- 
-             to_cat = [bos_embeds]
-             if embedding is not None:
-                 to_cat.append(embedding[i])
-             to_cat.append(inputs_embeds)
-             inputs_embeds = torch.cat(to_cat, dim=0)
- 
-             batch_inputs_embeds.append(inputs_embeds)
-             batch_attention_mask.append([1] * inputs_embeds.size(0))
-             label_input_ids = [IGNORE_INDEX] * (
-                 inputs_embeds.size(0) - len(label_input_ids)) + label_input_ids
-             batch_label_input_ids.append(label_input_ids)
- 
-         # Pad input embeddings:
-         max_length = max([x.size(0) for x in batch_inputs_embeds])
-         for i in range(batch_size):
-             pad = max_length - batch_inputs_embeds[i].size(0)
-             batch_inputs_embeds[i] = torch.cat([
-                 pad_embeds.repeat(pad, 1),
-                 batch_inputs_embeds[i],
-             ])
-             batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
-             batch_label_input_ids[i] = ([IGNORE_INDEX] * pad +
-                                         batch_label_input_ids[i])
- 
-         inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
-         attention_mask = torch.tensor(batch_attention_mask,
-                                       device=self.llm_device)
-         label_input_ids = torch.tensor(batch_label_input_ids,
-                                        device=self.llm_device)
+         inputs_embeds, attention_mask, label_input_ids = self._get_embeds(
+             question, context, embedding, answer)
  
          with self.autocast_context:
              outputs = self.llm(
@@ -219,52 +292,13 @@ class LLM(torch.nn.Module):
                  LLM, such as textified knowledge graphs. (default: :obj:`None`)
              embedding (list[torch.Tensor], optional): RAG embedding
                  tensors, *i.e.* the embedded form of :obj:`context`. Either
-                 :obj:`context` or :obj:`rag_embeddings` should be used, not
+                 :obj:`context` or :obj:`embedding` should be used, not
                  both. (default: :obj:`None`)
              max_tokens (int, optional): How many tokens for the LLM to
                  generate. (default: :obj:`32`)
          """
-         if context is not None and embedding is not None:
-             warnings.warn("Using both 'context' and 'embedding' is a waste of "
-                           "compute and memory")
- 
-         (batch_size, question, context, eos_user_tokens, bos_embeds,
-          pad_embeds) = self._encode_inputs(question, context)
- 
-         batch_inputs_embeds = []
-         batch_attention_mask = []
-         for i in range(batch_size):
-             input_ids: List[int] = []
-             if context is not None:
-                 input_ids = context.input_ids[i][:MAX_TXT_LEN]
-             input_ids += question.input_ids[i]
-             input_ids += eos_user_tokens.input_ids
- 
-             inputs_embeds = self.word_embedding(
-                 torch.tensor(input_ids, device=self.llm_device))
- 
-             to_cat = [bos_embeds]
-             if embedding is not None:
-                 to_cat.append(embedding[i])
-             to_cat.append(inputs_embeds)
-             inputs_embeds = torch.cat(to_cat, dim=0)
- 
-             batch_inputs_embeds.append(inputs_embeds)
-             batch_attention_mask.append([1] * inputs_embeds.size(0))
- 
-         # Pad input embeddings:
-         max_length = max([x.size(0) for x in batch_inputs_embeds])
-         for i in range(batch_size):
-             pad = max_length - batch_inputs_embeds[i].size(0)
-             batch_inputs_embeds[i] = torch.cat([
-                 pad_embeds.repeat(pad, 1),
-                 batch_inputs_embeds[i],
-             ])
-             batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
- 
-         inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
-         attention_mask = torch.tensor(batch_attention_mask,
-                                       device=self.llm_device)
+         inputs_embeds, attention_mask, _ = self._get_embeds(
+             question, context, embedding)
  
          bos_token = self.tokenizer(
              BOS,
@@ -281,3 +315,6 @@ class LLM(torch.nn.Module):
          )
  
          return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ 
+     def __repr__(self) -> str:
+         return f'{self.__class__.__name__}({self.model_name})'
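Net effect of the llm.py changes: the per-example prompt assembly that was previously duplicated in forward() and inference() now lives in the shared _get_embeds() helper, the device attribute is renamed from llm_device to device, and arbitrary HuggingFace model names are passed through instead of the old 'llama2-7b'/'gemma' aliases. A rough usage sketch under those assumptions (the prompt text is illustrative, and the return value of forward() being the LM loss is inferred rather than shown in this diff):

    from torch_geometric.nn.nlp import LLM

    llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7)

    # Training-style call: prompts are assembled via _get_embeds(), and the
    # causal LM is scored against the answer tokens:
    loss = llm(question=['What is PyG?'],
               answer=['A graph learning library for PyTorch.'])

    # Generation: same prompt assembly without labels, then .generate():
    out = llm.inference(question=['What is PyG?'], max_tokens=32)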
@@ -7,18 +7,19 @@ from torch import Tensor
  import torch_geometric.typing
  from torch_geometric.typing import OptTensor, torch_cluster
  
- from .asap import ASAPooling
  from .avg_pool import avg_pool, avg_pool_neighbor_x, avg_pool_x
- from .edge_pool import EdgePooling
  from .glob import global_add_pool, global_max_pool, global_mean_pool
  from .knn import (KNNIndex, L2KNNIndex, MIPSKNNIndex, ApproxL2KNNIndex,
                    ApproxMIPSKNNIndex)
  from .graclus import graclus
  from .max_pool import max_pool, max_pool_neighbor_x, max_pool_x
- from .mem_pool import MemPooling
- from .pan_pool import PANPooling
- from .sag_pool import SAGPooling
  from .topk_pool import TopKPooling
+ from .sag_pool import SAGPooling
+ from .edge_pool import EdgePooling
+ from .cluster_pool import ClusterPooling
+ from .asap import ASAPooling
+ from .pan_pool import PANPooling
+ from .mem_pool import MemPooling
  from .voxel_grid import voxel_grid
  from .approx_knn import approx_knn, approx_knn_graph
  
@@ -344,6 +345,7 @@ __all__ = [
      'TopKPooling',
      'SAGPooling',
      'EdgePooling',
+     'ClusterPooling',
      'ASAPooling',
      'PANPooling',
      'MemPooling',
@@ -0,0 +1,145 @@
+ from typing import NamedTuple, Optional, Tuple
+ 
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor
+ 
+ from torch_geometric.utils import (
+     dense_to_sparse,
+     one_hot,
+     to_dense_adj,
+     to_scipy_sparse_matrix,
+ )
+ 
+ 
+ class UnpoolInfo(NamedTuple):
+     edge_index: Tensor
+     cluster: Tensor
+     batch: Tensor
+ 
+ 
+ class ClusterPooling(torch.nn.Module):
+     r"""The cluster pooling operator from the `"Edge-Based Graph Component
+     Pooling" <paper url>`_ paper.
+ 
+     :class:`ClusterPooling` computes a score for each edge.
+     Based on the selected edges, graph clusters are calculated and compressed
+     to one node using the injective :obj:`"sum"` aggregation function.
+     Edges are remapped based on the nodes created by each cluster and the
+     original edges.
+ 
+     Args:
+         in_channels (int): Size of each input sample.
+         edge_score_method (str, optional): The function to apply
+             to compute the edge score from raw edge scores (:obj:`"tanh"`,
+             :obj:`"sigmoid"`, :obj:`"log_softmax"`). (default: :obj:`"tanh"`)
+         dropout (float, optional): The probability with
+             which to drop edge scores during training. (default: :obj:`0.0`)
+         threshold (float, optional): The threshold of edge scores. If set to
+             :obj:`None`, will be automatically inferred depending on
+             :obj:`edge_score_method`. (default: :obj:`None`)
+     """
+     def __init__(
+         self,
+         in_channels: int,
+         edge_score_method: str = 'tanh',
+         dropout: float = 0.0,
+         threshold: Optional[float] = None,
+     ):
+         super().__init__()
+         assert edge_score_method in ['tanh', 'sigmoid', 'log_softmax']
+ 
+         if threshold is None:
+             threshold = 0.5 if edge_score_method == 'sigmoid' else 0.0
+ 
+         self.in_channels = in_channels
+         self.edge_score_method = edge_score_method
+         self.dropout = dropout
+         self.threshold = threshold
+ 
+         self.lin = torch.nn.Linear(2 * in_channels, 1)
+ 
+     def reset_parameters(self):
+         r"""Resets all learnable parameters of the module."""
+         self.lin.reset_parameters()
+ 
+     def forward(
+         self,
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+     ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:
+         r"""Forward pass.
+ 
+         Args:
+             x (torch.Tensor): The node features.
+             edge_index (torch.Tensor): The edge indices.
+             batch (torch.Tensor): Batch vector
+                 :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
+                 each node to a specific example.
+ 
+         Return types:
+             * **x** *(torch.Tensor)* - The pooled node features.
+             * **edge_index** *(torch.Tensor)* - The coarsened edge indices.
+             * **batch** *(torch.Tensor)* - The coarsened batch vector.
+             * **unpool_info** *(UnpoolInfo)* - Information that can be consumed
+               for unpooling.
+         """
+         mask = edge_index[0] != edge_index[1]
+         edge_index = edge_index[:, mask]
+ 
+         edge_attr = torch.cat(
+             [x[edge_index[0]], x[edge_index[1]]],
+             dim=-1,
+         )
+         edge_score = self.lin(edge_attr).view(-1)
+         edge_score = F.dropout(edge_score, p=self.dropout,
+                                training=self.training)
+ 
+         if self.edge_score_method == 'tanh':
+             edge_score = edge_score.tanh()
+         elif self.edge_score_method == 'sigmoid':
+             edge_score = edge_score.sigmoid()
+         else:
+             assert self.edge_score_method == 'log_softmax'
+             edge_score = F.log_softmax(edge_score, dim=0)
+ 
+         return self._merge_edges(x, edge_index, batch, edge_score)
+ 
+     def _merge_edges(
+         self,
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+         edge_score: Tensor,
+     ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:
+ 
+         from scipy.sparse.csgraph import connected_components
+ 
+         edge_contract = edge_index[:, edge_score > self.threshold]
+ 
+         adj = to_scipy_sparse_matrix(edge_contract, num_nodes=x.size(0))
+         _, cluster_np = connected_components(adj, directed=True,
+                                              connection="weak")
+ 
+         cluster = torch.tensor(cluster_np, dtype=torch.long, device=x.device)
+         C = one_hot(cluster)
+         A = to_dense_adj(edge_index, max_num_nodes=x.size(0)).squeeze(0)
+         S = to_dense_adj(edge_index, edge_attr=edge_score,
+                          max_num_nodes=x.size(0)).squeeze(0)
+ 
+         A_contract = to_dense_adj(edge_contract,
+                                   max_num_nodes=x.size(0)).squeeze(0)
+         nodes_single = ((A_contract.sum(dim=-1) +
+                          A_contract.sum(dim=-2)) == 0).nonzero()
+         S[nodes_single, nodes_single] = 1.0
+ 
+         x_out = (S @ C).t() @ x
+         edge_index_out, _ = dense_to_sparse((C.T @ A @ C).fill_diagonal_(0))
+         batch_out = batch.new_empty(x_out.size(0)).scatter_(0, cluster, batch)
+         unpool_info = UnpoolInfo(edge_index, cluster, batch)
+ 
+         return x_out, edge_index_out, batch_out, unpool_info
+ 
+     def __repr__(self) -> str:
+         return f'{self.__class__.__name__}({self.in_channels})'
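A minimal shape-level sketch of the new ClusterPooling operator (the toy graph below is illustrative). Edges scoring above the threshold are contracted, and each resulting connected component collapses to a single pooled node:

    import torch
    from torch_geometric.nn.pool import ClusterPooling

    pool = ClusterPooling(in_channels=16, edge_score_method='sigmoid')

    x = torch.randn(6, 16)                    # 6 nodes, 16 features each
    edge_index = torch.tensor([[0, 1, 1, 2, 4, 5],
                               [1, 0, 2, 1, 5, 4]])
    batch = torch.zeros(6, dtype=torch.long)  # a single example

    # Returns pooled features, coarsened edges/batch, and unpooling info:
    x_out, edge_index_out, batch_out, unpool_info = pool(x, edge_index, batch)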
@@ -58,7 +58,7 @@ class EdgePooling(torch.nn.Module):
          self,
          in_channels: int,
          edge_score_method: Optional[Callable] = None,
-         dropout: Optional[float] = 0.0,
+         dropout: float = 0.0,
          add_to_edge_score: float = 0.5,
      ):
          super().__init__()