pyg-nightly 2.6.0.dev20240909__py3-none-any.whl → 2.6.0.dev20240910__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.6.0.dev20240909.dist-info → pyg_nightly-2.6.0.dev20240910.dist-info}/METADATA +1 -1
- {pyg_nightly-2.6.0.dev20240909.dist-info → pyg_nightly-2.6.0.dev20240910.dist-info}/RECORD +10 -8
- torch_geometric/__init__.py +1 -1
- torch_geometric/nn/models/__init__.py +2 -0
- torch_geometric/nn/models/g_retriever.py +205 -0
- torch_geometric/nn/nlp/llm.py +152 -115
- torch_geometric/nn/pool/__init__.py +7 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/edge_pool.py +1 -1
- {pyg_nightly-2.6.0.dev20240909.dist-info → pyg_nightly-2.6.0.dev20240910.dist-info}/WHEEL +0 -0
{pyg_nightly-2.6.0.dev20240909.dist-info → pyg_nightly-2.6.0.dev20240910.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: pyg-nightly
|
3
|
-
Version: 2.6.0.
|
3
|
+
Version: 2.6.0.dev20240910
|
4
4
|
Summary: Graph Neural Network Library for PyTorch
|
5
5
|
Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
|
6
6
|
Author-email: Matthias Fey <matthias@pyg.org>
|
@@ -1,4 +1,4 @@
|
|
1
|
-
torch_geometric/__init__.py,sha256=
|
1
|
+
torch_geometric/__init__.py,sha256=vRKXyHIGqBHJUKx9tLap5c3uB1Mbb6ZOlvVgapW_D6Q,1904
|
2
2
|
torch_geometric/_compile.py,sha256=0HAdz6MGmyrgi4g6P-PorTg8dPIKx3Jo4zVJavrlfX0,1139
|
3
3
|
torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
|
4
4
|
torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
|
@@ -416,7 +416,7 @@ torch_geometric/nn/kge/distmult.py,sha256=dGQ0bVzjreZgFN1lXE23_IIidsiOq7ehPrMb-N
|
|
416
416
|
torch_geometric/nn/kge/loader.py,sha256=5Uc1j3OUMQnBYSHDqL7pLCty1siFLzoPkztigYO2zP8,771
|
417
417
|
torch_geometric/nn/kge/rotate.py,sha256=XLuO1AbyTt5cJxr97ZzoyAyIEsHKesgW5TvDmnGJAao,3208
|
418
418
|
torch_geometric/nn/kge/transe.py,sha256=jlejq5BLMm-sb1wWcLDp7pZqCdelWBgjDIC8ctbjSdU,3088
|
419
|
-
torch_geometric/nn/models/__init__.py,sha256=
|
419
|
+
torch_geometric/nn/models/__init__.py,sha256=RpYFFqaYWq1BVMF3Fs-EQo-QZDdLQjIHPdkl3d2MOW4,2017
|
420
420
|
torch_geometric/nn/models/attentive_fp.py,sha256=tkgvw28wg9-JqHIfBllfCwTHrZIUiv85yZJcDqjz3z0,6634
|
421
421
|
torch_geometric/nn/models/autoencoder.py,sha256=nGje-zty78Y3hxOJ9o0_6QziJjOvBlknk6z0_fDQwQU,10770
|
422
422
|
torch_geometric/nn/models/basic_gnn.py,sha256=PGa0RUMyvrNy_5yRI2jX_zwPsmZXwOQWfsWvxOiHsSk,31225
|
@@ -426,6 +426,7 @@ torch_geometric/nn/models/deep_graph_infomax.py,sha256=u6j-5-iHBASDCZ776dyfCI1N8
|
|
426
426
|
torch_geometric/nn/models/deepgcn.py,sha256=tIgT03cj8MghYlxEozpoGvGG_CwpJrGDxv1Z0CVIUts,4339
|
427
427
|
torch_geometric/nn/models/dimenet.py,sha256=Kc5p-rB5q-0e8lY22l-OdQTscTxJh2lTEpeRFMdL4RY,36186
|
428
428
|
torch_geometric/nn/models/dimenet_utils.py,sha256=xP_nbzkSSL25GC3rrZ9KP8x9QZ59S-CZuHzCmQ-K0fI,5062
|
429
|
+
torch_geometric/nn/models/g_retriever.py,sha256=uH_aYrFbFNHaAeKQn_LtUgP5ajutLYYD8N9UvSKcpfk,7271
|
429
430
|
torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
|
430
431
|
torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
|
431
432
|
torch_geometric/nn/models/graph_unet.py,sha256=WFb7d_DBByMGyXh3AdK2CKNmvMmSKsSUt8l8UnSOovs,5395
|
@@ -448,7 +449,7 @@ torch_geometric/nn/models/signed_gcn.py,sha256=J40CnedFIqtKI1LhW1ITSEFRbA_XiJZL6
|
|
448
449
|
torch_geometric/nn/models/tgn.py,sha256=kEGdfLJybkbMT4UMoAh2nCzfX3_nDjfm1cicuPHEwAM,11878
|
449
450
|
torch_geometric/nn/models/visnet.py,sha256=97OFMCsPDEI5BCSi7RhoRcU2CNRp7zck2tEzrltFZj4,43192
|
450
451
|
torch_geometric/nn/nlp/__init__.py,sha256=JJESTA7w_K8v60XbCd25IqmrKKHLz5OiNexMHYGV2mE,138
|
451
|
-
torch_geometric/nn/nlp/llm.py,sha256=
|
452
|
+
torch_geometric/nn/nlp/llm.py,sha256=KwSXgI55FuHLR_9vhgekDXMaRUodPQceHPD7OCp2KN4,11639
|
452
453
|
torch_geometric/nn/nlp/sentence_transformer.py,sha256=DzbQO8wgR34BkKpXfMqQu61hMrK94W2MBa3bZ4fDmVs,3114
|
453
454
|
torch_geometric/nn/norm/__init__.py,sha256=u2qIDrkbeuObGVXSAIftAlvSd6ouGTtxznCfD-59UiA,669
|
454
455
|
torch_geometric/nn/norm/batch_norm.py,sha256=sJKrinHGwA-noIgteg1RD2W06rd0zskD-rXuY-36glY,8283
|
@@ -460,13 +461,14 @@ torch_geometric/nn/norm/layer_norm.py,sha256=pWo5q8rLNSaU2fECpP7L8T_airtaukjOztL
|
|
460
461
|
torch_geometric/nn/norm/mean_subtraction_norm.py,sha256=KVHOp413mw7obwAN09Le6XdgobtCXpi4UKpjpG1M550,1322
|
461
462
|
torch_geometric/nn/norm/msg_norm.py,sha256=zaQtqhs55LU-e6KPC4ylaSdge4KvEoseqOt7pmAzi2s,1662
|
462
463
|
torch_geometric/nn/norm/pair_norm.py,sha256=IfHMiVYw_xsy035NakbPGdQVaVC-Ue3Oxwo651Vc47I,2824
|
463
|
-
torch_geometric/nn/pool/__init__.py,sha256=
|
464
|
+
torch_geometric/nn/pool/__init__.py,sha256=2Bi-_xlsGIUUKDeOO7BhaTqCc5n6_ixbu_MO9pglMts,14192
|
464
465
|
torch_geometric/nn/pool/approx_knn.py,sha256=n7C8Cbar6o5tJcuAbzhM5hqMK26hW8dm5DopuocidO0,3967
|
465
466
|
torch_geometric/nn/pool/asap.py,sha256=p8fwpMOeCUyJrdvMmLoTMzr0tI9YCTnefMx8ylIv5xE,6683
|
466
467
|
torch_geometric/nn/pool/avg_pool.py,sha256=pwiQh14BCVsT-iULqVAFW-Dxt7DjFOu8CQX_Hu34vZc,3966
|
468
|
+
torch_geometric/nn/pool/cluster_pool.py,sha256=et2YaFu1kf-o6Eg9XpqHGp_Cqv68DndWbE88VJHOSPQ,5227
|
467
469
|
torch_geometric/nn/pool/consecutive.py,sha256=7dMiMd5IybNeml1RqZq436FI6sod5ZUxTuDWJjr5syo,273
|
468
470
|
torch_geometric/nn/pool/decimation.py,sha256=AjbU2h_Gl_EQcfkhF977EnrLJ2kait_e4HyCNKRyxPw,1601
|
469
|
-
torch_geometric/nn/pool/edge_pool.py,sha256=
|
471
|
+
torch_geometric/nn/pool/edge_pool.py,sha256=cXgcN5xF8z5NeycYMX9m1zoAk1jtSdyK42YiNNHTeow,8571
|
470
472
|
torch_geometric/nn/pool/glob.py,sha256=RJrq1sgAe8oV15WSGtXgB6yXWj2irSJIWAdQLb0byN4,3492
|
471
473
|
torch_geometric/nn/pool/graclus.py,sha256=dL9tasXNM-x2NOMRJn8k6z4CeW46nRzoa49IzG58wow,1349
|
472
474
|
torch_geometric/nn/pool/knn.py,sha256=fNZV0q2A4lzhZQyePRLHSrtuWjbxQxvv3V7oeNzBLVk,11343
|
@@ -615,6 +617,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
|
|
615
617
|
torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
|
616
618
|
torch_geometric/visualization/graph.py,sha256=SvbdVx5Zmuy_WSSA4-WWCkqAcCSHVe84mjMfsEWbZCs,4813
|
617
619
|
torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
|
618
|
-
pyg_nightly-2.6.0.
|
619
|
-
pyg_nightly-2.6.0.
|
620
|
-
pyg_nightly-2.6.0.
|
620
|
+
pyg_nightly-2.6.0.dev20240910.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
|
621
|
+
pyg_nightly-2.6.0.dev20240910.dist-info/METADATA,sha256=4d3E2ca0L5gmd30HK26OTddin11CJkS3iteHVYcnxfI,63068
|
622
|
+
pyg_nightly-2.6.0.dev20240910.dist-info/RECORD,,
|
torch_geometric/__init__.py
CHANGED
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
|
|
30
30
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
31
31
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
32
32
|
|
33
|
-
__version__ = '2.6.0.
|
33
|
+
__version__ = '2.6.0.dev20240910'
|
34
34
|
|
35
35
|
__all__ = [
|
36
36
|
'Index',
|
@@ -28,6 +28,7 @@ from .gnnff import GNNFF
|
|
28
28
|
from .pmlp import PMLP
|
29
29
|
from .neural_fingerprint import NeuralFingerprint
|
30
30
|
from .visnet import ViSNet
|
31
|
+
from .g_retriever import GRetriever
|
31
32
|
|
32
33
|
# Deprecated:
|
33
34
|
from torch_geometric.explain.algorithm.captum import (to_captum_input,
|
@@ -75,4 +76,5 @@ __all__ = classes = [
|
|
75
76
|
'PMLP',
|
76
77
|
'NeuralFingerprint',
|
77
78
|
'ViSNet',
|
79
|
+
'GRetriever',
|
78
80
|
]
|
@@ -0,0 +1,205 @@
|
|
1
|
+
from typing import List, Optional
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from torch import Tensor
|
5
|
+
|
6
|
+
from torch_geometric.nn.models import GAT
|
7
|
+
from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS
|
8
|
+
from torch_geometric.utils import scatter
|
9
|
+
|
10
|
+
|
11
|
+
class GRetriever(torch.nn.Module):
    r"""The G-Retriever model from the `"G-Retriever: Retrieval-Augmented
    Generation for Textual Graph Understanding and Question Answering"
    <https://arxiv.org/abs/2402.07630>`_ paper.

    Each graph in the batch is encoded by :obj:`gnn`, mean-pooled into a
    single vector per graph, projected by a small MLP, and handed to the LLM
    together with the question at the same batch position.

    Args:
        llm (LLM): The LLM to use.
        gnn (torch.nn.Module): The GNN to use. It must expose an
            :obj:`out_channels` attribute, which sets the input size of the
            projection MLP.
        use_lora (bool, optional): If set to :obj:`True`, will use LORA from
            :obj:`peft` for training the LLM, see
            `here <https://huggingface.co/docs/peft/en/index>`_ for details.
            (default: :obj:`False`)
        gnn_to_use (optional): Currently unused; kept only for backward
            compatibility. The GNN is always taken from :obj:`gnn`.
            (default: :class:`GAT`)
        mlp_out_channels (int, optional): The size of each graph embedding
            after projection. Presumably this has to match the LLM's token
            embedding size, since the projected embedding is concatenated
            with token embeddings along the sequence dimension.
            (default: :obj:`4096`)

    .. warning::
        This module has been tested with the following HuggingFace models

        * :obj:`llm_to_use="meta-llama/Llama-2-7b-chat-hf"`
        * :obj:`llm_to_use="google/gemma-7b"`

        and may not work with other models. See other models at `HuggingFace
        Models <https://huggingface.co/models>`_ and let us know if you
        encounter any issues.

    .. note::
        For an example of using :class:`GRetriever`, see
        `examples/llm/g_retriever.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/llm/g_retriever.py>`_.
    """
    def __init__(
        self,
        llm: LLM,
        gnn: torch.nn.Module,
        use_lora: bool = False,
        gnn_to_use=GAT,
        mlp_out_channels: int = 4096,
    ) -> None:
        super().__init__()

        self.llm = llm
        self.gnn = gnn.to(self.llm.device)

        self.word_embedding = self.llm.word_embedding
        self.llm_generator = self.llm.llm
        if use_lora:
            # Imported lazily so that `peft` is only required when LoRA is
            # actually requested.
            from peft import (
                LoraConfig,
                get_peft_model,
                prepare_model_for_kbit_training,
            )
            self.llm_generator = prepare_model_for_kbit_training(
                self.llm_generator)
            lora_r: int = 8
            lora_alpha: int = 16
            lora_dropout: float = 0.05
            lora_target_modules = ['q_proj', 'v_proj']
            config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules=lora_target_modules,
                lora_dropout=lora_dropout,
                bias='none',
                task_type='CAUSAL_LM',
            )
            self.llm_generator = get_peft_model(self.llm_generator, config)

        mlp_hidden_channels = self.gnn.out_channels
        self.projector = torch.nn.Sequential(
            torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
            torch.nn.Sigmoid(),
            torch.nn.Linear(mlp_hidden_channels, mlp_out_channels),
        ).to(self.llm.device)

    def encode(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        edge_attr: Optional[Tensor],
    ) -> Tensor:
        r"""Runs the GNN and mean-pools node embeddings per graph.

        All inputs are moved to the LLM's device first.

        Returns:
            A tensor with one row per graph index in :obj:`batch` (rows are
            produced by a ``scatter`` mean over :obj:`batch`).
        """
        x = x.to(self.llm.device)
        edge_index = edge_index.to(self.llm.device)
        if edge_attr is not None:
            edge_attr = edge_attr.to(self.llm.device)
        batch = batch.to(self.llm.device)

        out = self.gnn(x, edge_index, edge_attr=edge_attr)
        return scatter(out, batch, dim=0, reduce='mean')

    def _graph_embeds(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        edge_attr: Optional[Tensor],
    ) -> List[Tensor]:
        r"""Encodes and projects the graphs into one embedding per graph.

        Shared by :meth:`forward` and :meth:`inference`.
        """
        x = self.encode(x, edge_index, batch, edge_attr)
        x = self.projector(x)
        # Split into chunks of size 1 so that entry ``i`` (a ``[1, ...]``
        # slice) belongs to graph/question ``i``. Splitting by
        # ``x.size(0)`` would return a single chunk containing *all* graphs,
        # which breaks any batch with more than one example.
        return list(x.split(1, dim=0))

    def forward(
        self,
        question: List[str],
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        label: List[str],
        edge_attr: Optional[Tensor] = None,
        additional_text_context: Optional[List[str]] = None,
    ):
        r"""The forward pass.

        Args:
            question (List[str]): The questions/prompts.
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor): The edge indices.
            batch (torch.Tensor): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each element to a specific example.
            label (List[str]): The answers/labels.
            edge_attr (torch.Tensor, optional): The edge features (if supported
                by the GNN). (default: :obj:`None`)
            additional_text_context (List[str], optional): Additional context
                to give to the LLM, such as textified knowledge graphs.
                (default: :obj:`None`)

        Returns:
            The language-modeling loss reported by the LLM
            (:obj:`outputs.loss`).
        """
        xs = self._graph_embeds(x, edge_index, batch, edge_attr)

        (
            inputs_embeds,
            attention_mask,
            label_input_ids,
        ) = self.llm._get_embeds(question, additional_text_context, xs, label)

        with self.llm.autocast_context:
            outputs = self.llm_generator(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=label_input_ids,
            )

        return outputs.loss

    @torch.no_grad()
    def inference(
        self,
        question: List[str],
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        edge_attr: Optional[Tensor] = None,
        additional_text_context: Optional[List[str]] = None,
        max_out_tokens: Optional[int] = MAX_NEW_TOKENS,
    ):
        r"""The inference pass.

        Args:
            question (List[str]): The questions/prompts.
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor): The edge indices.
            batch (torch.Tensor): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each element to a specific example.
            edge_attr (torch.Tensor, optional): The edge features (if supported
                by the GNN). (default: :obj:`None`)
            additional_text_context (List[str], optional): Additional context
                to give to the LLM, such as textified knowledge graphs.
                (default: :obj:`None`)
            max_out_tokens (int, optional): How many tokens for the LLM to
                generate. (default: :obj:`MAX_NEW_TOKENS`)

        Returns:
            The generated texts, decoded with special tokens skipped.
        """
        xs = self._graph_embeds(x, edge_index, batch, edge_attr)

        inputs_embeds, attention_mask, _ = self.llm._get_embeds(
            question, additional_text_context, xs)

        bos_token = self.llm.tokenizer(
            BOS,
            add_special_tokens=False,
        ).input_ids[0]

        with self.llm.autocast_context:
            outputs = self.llm_generator.generate(
                inputs_embeds=inputs_embeds,
                max_new_tokens=max_out_tokens,
                attention_mask=attention_mask,
                bos_token_id=bos_token,
                use_cache=True  # Important to set!
            )

        return self.llm.tokenizer.batch_decode(
            outputs,
            skip_special_tokens=True,
        )

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(\n'
                f'  llm={self.llm},\n'
                f'  gnn={self.gnn},\n'
                f')')
|
torch_geometric/nn/nlp/llm.py
CHANGED
@@ -1,10 +1,14 @@
|
|
1
|
-
import warnings
|
2
1
|
from contextlib import nullcontext
|
3
2
|
from typing import Any, Dict, List, Optional
|
4
3
|
|
5
4
|
import torch
|
6
5
|
from torch import Tensor
|
7
6
|
|
7
|
+
try:
|
8
|
+
from transformers.tokenization_utils_base import BatchEncoding
|
9
|
+
except ImportError:
|
10
|
+
BatchEncoding = Dict
|
11
|
+
|
8
12
|
BOS = '<s>[INST]'
|
9
13
|
EOS_USER = '[/INST]'
|
10
14
|
EOS = '[/s]'
|
@@ -61,23 +65,16 @@ class LLM(torch.nn.Module):
|
|
61
65
|
) -> None:
|
62
66
|
super().__init__()
|
63
67
|
|
64
|
-
|
68
|
+
self.model_name = model_name
|
65
69
|
|
66
|
-
|
67
|
-
pretty_model_name = 'LLAMA2'
|
68
|
-
model_name = 'meta-llama/Llama-2-7b-chat-hf'
|
69
|
-
elif model_name == 'gemma':
|
70
|
-
pretty_model_name = 'GEMMA'
|
71
|
-
model_name = 'google/gemma-7b'
|
72
|
-
else:
|
73
|
-
pretty_model_name = model_name
|
70
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
74
71
|
|
75
72
|
# A rough heuristic on GPU memory requirements, e.g., we found that
|
76
73
|
# LLAMA2 (7B parameters) fits on a 85GB GPU.
|
77
74
|
required_memory = 85 * num_params / 7
|
78
75
|
kwargs = get_llm_kwargs(required_memory, dtype)
|
79
76
|
|
80
|
-
print(f"Setting up '{
|
77
|
+
print(f"Setting up '{model_name}' with configuration: {kwargs}")
|
81
78
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
82
79
|
model_name,
|
83
80
|
use_fast=False,
|
@@ -88,17 +85,17 @@ class LLM(torch.nn.Module):
|
|
88
85
|
self.word_embedding = self.llm.model.get_input_embeddings()
|
89
86
|
|
90
87
|
if 'max_memory' not in kwargs: # Pure CPU:
|
91
|
-
self.
|
88
|
+
self.device = torch.device('cpu')
|
92
89
|
self.autocast_context = nullcontext()
|
93
90
|
else:
|
94
|
-
self.
|
91
|
+
self.device = self.llm.device
|
95
92
|
self.autocast_context = torch.cuda.amp.autocast(dtype=dtype)
|
96
93
|
|
97
94
|
def _encode_inputs(
|
98
95
|
self,
|
99
96
|
question: List[str],
|
100
97
|
context: Optional[List[str]] = None,
|
101
|
-
) ->
|
98
|
+
) -> tuple:
|
102
99
|
batch_size = len(question)
|
103
100
|
questions = self.tokenizer(question, add_special_tokens=False)
|
104
101
|
if context is not None:
|
@@ -109,14 +106,144 @@ class LLM(torch.nn.Module):
|
|
109
106
|
BOS,
|
110
107
|
add_special_tokens=False,
|
111
108
|
return_tensors='pt',
|
112
|
-
).input_ids[0].to(self.
|
109
|
+
).input_ids[0].to(self.device)
|
113
110
|
bos_embeds = self.word_embedding(bos_token)
|
114
111
|
pad_token = torch.tensor(self.tokenizer.pad_token_id,
|
115
|
-
device=self.
|
112
|
+
device=self.device)
|
116
113
|
pad_embeds = self.word_embedding(pad_token).unsqueeze(0)
|
117
114
|
return (batch_size, questions, context, eos_user_tokens, bos_embeds,
|
118
115
|
pad_embeds)
|
119
116
|
|
117
|
+
def _label_input_ids(
    self,
    i: int,
    label: BatchEncoding,
    eos_tokens: BatchEncoding,
) -> List[int]:
    """Return the target token ids for the ``i``-th answer.

    The tokenized answer is cut to at most ``MAX_NEW_TOKENS`` ids and the
    tokenized EOS marker is appended after it.
    """
    truncated_answer = label.input_ids[i][:MAX_NEW_TOKENS]
    return truncated_answer + eos_tokens.input_ids
|
126
|
+
|
127
|
+
def _input_ids(
    self,
    i: int,
    context: BatchEncoding,
    question: BatchEncoding,
    eos_user_tokens: BatchEncoding,
) -> List[int]:
    """Assemble the prompt token ids for the ``i``-th example.

    Layout: optional context (cut to ``MAX_TXT_LEN`` ids), then the question,
    then the end-of-user-turn tokens. ``context`` may be ``None``.
    """
    segments = []
    if context is not None:
        segments.append(context.input_ids[i][:MAX_TXT_LEN])
    segments.append(question.input_ids[i])
    segments.append(eos_user_tokens.input_ids)

    token_ids: List[int] = []
    for segment in segments:
        token_ids.extend(segment)
    return token_ids
|
140
|
+
|
141
|
+
def _inputs_embeds(
    self,
    i: int,
    input_ids: List[int],
    bos_embeds: Tensor,
    embedding: Optional[List[Tensor]] = None,
) -> Tensor:
    """Build the embedding sequence for the ``i``-th example.

    The result is ``bos_embeds``, followed by ``embedding[i]`` (only when
    ``embedding`` is given and that entry is not ``None``), followed by the
    word embeddings of ``input_ids``. Pieces are concatenated along dim 0
    and moved to ``self.device``.
    """
    token_embeds = self.word_embedding(
        torch.tensor(input_ids, device=self.device))

    extra = None
    if embedding is not None:
        extra = embedding[i]

    if extra is None:
        pieces = [bos_embeds, token_embeds]
    else:
        pieces = [bos_embeds, extra, token_embeds]
    return torch.cat(pieces, dim=0).to(self.device)
|
156
|
+
|
157
|
+
def _append_embeds(
    self,
    inputs_embeds: Tensor,
    batch_inputs_embeds: List[Tensor],
    batch_attention_mask: List[List[int]],
    label_input_ids: List[int] = None,
    batch_label_input_ids: Optional[List[List[int]]] = None,
) -> tuple:
    """Add one example to the running batch lists.

    The lists are mutated in place and also returned. The attention mask of
    the new example is all ones. When ``label_input_ids`` is given, it is
    left-filled with ``IGNORE_INDEX`` so its length equals the example's
    sequence length, and then appended to ``batch_label_input_ids``.
    """
    seq_len = inputs_embeds.size(0)
    batch_inputs_embeds.append(inputs_embeds)
    batch_attention_mask.append([1] * seq_len)

    if label_input_ids is not None:
        num_ignored = seq_len - len(label_input_ids)
        aligned_labels = [IGNORE_INDEX] * num_ignored + label_input_ids
        batch_label_input_ids.append(aligned_labels)

    return batch_inputs_embeds, batch_attention_mask, batch_label_input_ids
|
172
|
+
|
173
|
+
def _pad_embeds(
    self,
    pad_embeds: Tensor,
    batch_inputs_embeds: List[Tensor],
    batch_attention_mask: List[List[int]],
    batch_label_input_ids: Optional[List[List[int]]] = None,
) -> tuple:
    """Left-pad every example to the longest one and stack the batch.

    Padding is prepended: ``pad_embeds`` repeated for the embeddings, ``0``
    entries for the attention mask and ``IGNORE_INDEX`` for the labels (when
    labels are given). The input lists are updated in place.

    Returns:
        ``(inputs_embeds, attention_mask, label_input_ids)``, where
        ``label_input_ids`` is ``None`` if no labels were passed.
    """
    longest = max(seq.size(0) for seq in batch_inputs_embeds)

    for idx, seq in enumerate(batch_inputs_embeds):
        n_pad = longest - seq.size(0)
        batch_inputs_embeds[idx] = torch.cat([pad_embeds.repeat(n_pad, 1), seq])
        batch_attention_mask[idx] = [0] * n_pad + batch_attention_mask[idx]
        if batch_label_input_ids is not None:
            padded_labels = [IGNORE_INDEX] * n_pad + batch_label_input_ids[idx]
            batch_label_input_ids[idx] = padded_labels

    stacked_embeds = torch.stack(batch_inputs_embeds, dim=0)
    mask = torch.tensor(batch_attention_mask, device=self.device)
    if batch_label_input_ids is None:
        labels = None
    else:
        labels = torch.tensor(batch_label_input_ids, device=self.device)
    return stacked_embeds, mask, labels
|
199
|
+
|
200
|
+
def _get_embeds(
    self,
    question: List[str],
    context: Optional[List[str]] = None,
    embedding: Optional[List[Tensor]] = None,
    answer: Optional[List[str]] = None,
) -> tuple:
    """Build padded LLM inputs for a batch of questions.

    For each example the token layout comes from :meth:`_input_ids`. When
    ``answer`` is given, the answer tokens (see :meth:`_label_input_ids`)
    are appended and training labels are produced as well. Embeddings are
    built per example, collected, then left-padded and stacked.

    Returns:
        ``(inputs_embeds, attention_mask, label_input_ids)``; the last item
        is ``None`` when ``answer`` is ``None``.
    """
    (n_examples, enc_question, enc_context, eos_user_tokens, bos_embeds,
     pad_embeds) = self._encode_inputs(question, context)

    has_answer = answer is not None
    batch_label_input_ids = None
    if has_answer:
        label = self.tokenizer(answer, add_special_tokens=False)
        eos_tokens = self.tokenizer(EOS, add_special_tokens=False)
        batch_label_input_ids = []

    batch_inputs_embeds = []
    batch_attention_mask = []
    for i in range(n_examples):
        token_ids = self._input_ids(i, enc_context, enc_question,
                                    eos_user_tokens)
        target_ids = None
        if has_answer:
            target_ids = self._label_input_ids(i, label, eos_tokens)
            token_ids += target_ids

        example_embeds = self._inputs_embeds(i, token_ids, bos_embeds,
                                             embedding)

        batch_inputs_embeds, batch_attention_mask, batch_label_input_ids = (
            self._append_embeds(
                example_embeds,
                batch_inputs_embeds,
                batch_attention_mask,
                target_ids,
                batch_label_input_ids,
            ))

    return self._pad_embeds(pad_embeds, batch_inputs_embeds,
                            batch_attention_mask, batch_label_input_ids)
|
246
|
+
|
120
247
|
def forward(
|
121
248
|
self,
|
122
249
|
question: List[str],
|
@@ -133,65 +260,11 @@ class LLM(torch.nn.Module):
|
|
133
260
|
LLM, such as textified knowledge graphs. (default: :obj:`None`)
|
134
261
|
embedding (list[torch.Tensor], optional): RAG embedding
|
135
262
|
tensors, *i.e.* the embedded form of :obj:`context`. Either
|
136
|
-
:obj:`context` or :obj:`
|
263
|
+
:obj:`context` or :obj:`embedding` should be used, not
|
137
264
|
both. (default: :obj:`None`)
|
138
265
|
"""
|
139
|
-
|
140
|
-
|
141
|
-
"compute and memory")
|
142
|
-
|
143
|
-
(batch_size, question, context, eos_user_tokens, bos_embeds,
|
144
|
-
pad_embeds) = self._encode_inputs(question, context)
|
145
|
-
|
146
|
-
label = self.tokenizer(answer, add_special_tokens=False)
|
147
|
-
eos_tokens = self.tokenizer(EOS, add_special_tokens=False)
|
148
|
-
|
149
|
-
batch_inputs_embeds = []
|
150
|
-
batch_attention_mask = []
|
151
|
-
batch_label_input_ids = []
|
152
|
-
for i in range(batch_size):
|
153
|
-
label_input_ids = label.input_ids[i][:MAX_NEW_TOKENS]
|
154
|
-
label_input_ids += eos_tokens.input_ids # Add EOS token.
|
155
|
-
|
156
|
-
input_ids: List[int] = []
|
157
|
-
if context is not None:
|
158
|
-
input_ids += context.input_ids[i][:MAX_TXT_LEN]
|
159
|
-
input_ids += question.input_ids[i]
|
160
|
-
input_ids += eos_user_tokens.input_ids
|
161
|
-
input_ids += label_input_ids
|
162
|
-
|
163
|
-
inputs_embeds = self.word_embedding(
|
164
|
-
torch.tensor(input_ids, device=self.llm_device))
|
165
|
-
|
166
|
-
to_cat = [bos_embeds]
|
167
|
-
if embedding is not None:
|
168
|
-
to_cat.append(embedding[i])
|
169
|
-
to_cat.append(inputs_embeds)
|
170
|
-
inputs_embeds = torch.cat(to_cat, dim=0)
|
171
|
-
|
172
|
-
batch_inputs_embeds.append(inputs_embeds)
|
173
|
-
batch_attention_mask.append([1] * inputs_embeds.size(0))
|
174
|
-
label_input_ids = [IGNORE_INDEX] * (
|
175
|
-
inputs_embeds.size(0) - len(label_input_ids)) + label_input_ids
|
176
|
-
batch_label_input_ids.append(label_input_ids)
|
177
|
-
|
178
|
-
# Pad input embeddings:
|
179
|
-
max_length = max([x.size(0) for x in batch_inputs_embeds])
|
180
|
-
for i in range(batch_size):
|
181
|
-
pad = max_length - batch_inputs_embeds[i].size(0)
|
182
|
-
batch_inputs_embeds[i] = torch.cat([
|
183
|
-
pad_embeds.repeat(pad, 1),
|
184
|
-
batch_inputs_embeds[i],
|
185
|
-
])
|
186
|
-
batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
|
187
|
-
batch_label_input_ids[i] = ([IGNORE_INDEX] * pad +
|
188
|
-
batch_label_input_ids[i])
|
189
|
-
|
190
|
-
inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
|
191
|
-
attention_mask = torch.tensor(batch_attention_mask,
|
192
|
-
device=self.llm_device)
|
193
|
-
label_input_ids = torch.tensor(batch_label_input_ids,
|
194
|
-
device=self.llm_device)
|
266
|
+
inputs_embeds, attention_mask, label_input_ids = self._get_embeds(
|
267
|
+
question, context, embedding, answer)
|
195
268
|
|
196
269
|
with self.autocast_context:
|
197
270
|
outputs = self.llm(
|
@@ -219,52 +292,13 @@ class LLM(torch.nn.Module):
|
|
219
292
|
LLM, such as textified knowledge graphs. (default: :obj:`None`)
|
220
293
|
embedding (list[torch.Tensor], optional): RAG embedding
|
221
294
|
tensors, *i.e.* the embedded form of :obj:`context`. Either
|
222
|
-
:obj:`context` or :obj:`
|
295
|
+
:obj:`context` or :obj:`embedding` should be used, not
|
223
296
|
both. (default: :obj:`None`)
|
224
297
|
max_tokens (int, optional): How many tokens for the LLM to
|
225
298
|
generate. (default: :obj:`32`)
|
226
299
|
"""
|
227
|
-
|
228
|
-
|
229
|
-
"compute and memory")
|
230
|
-
|
231
|
-
(batch_size, question, context, eos_user_tokens, bos_embeds,
|
232
|
-
pad_embeds) = self._encode_inputs(question, context)
|
233
|
-
|
234
|
-
batch_inputs_embeds = []
|
235
|
-
batch_attention_mask = []
|
236
|
-
for i in range(batch_size):
|
237
|
-
input_ids: List[int] = []
|
238
|
-
if context is not None:
|
239
|
-
input_ids = context.input_ids[i][:MAX_TXT_LEN]
|
240
|
-
input_ids += question.input_ids[i]
|
241
|
-
input_ids += eos_user_tokens.input_ids
|
242
|
-
|
243
|
-
inputs_embeds = self.word_embedding(
|
244
|
-
torch.tensor(input_ids, device=self.llm_device))
|
245
|
-
|
246
|
-
to_cat = [bos_embeds]
|
247
|
-
if embedding is not None:
|
248
|
-
to_cat.append(embedding[i])
|
249
|
-
to_cat.append(inputs_embeds)
|
250
|
-
inputs_embeds = torch.cat(to_cat, dim=0)
|
251
|
-
|
252
|
-
batch_inputs_embeds.append(inputs_embeds)
|
253
|
-
batch_attention_mask.append([1] * inputs_embeds.size(0))
|
254
|
-
|
255
|
-
# Pad input embeddings:
|
256
|
-
max_length = max([x.size(0) for x in batch_inputs_embeds])
|
257
|
-
for i in range(batch_size):
|
258
|
-
pad = max_length - batch_inputs_embeds[i].size(0)
|
259
|
-
batch_inputs_embeds[i] = torch.cat([
|
260
|
-
pad_embeds.repeat(pad, 1),
|
261
|
-
batch_inputs_embeds[i],
|
262
|
-
])
|
263
|
-
batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
|
264
|
-
|
265
|
-
inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
|
266
|
-
attention_mask = torch.tensor(batch_attention_mask,
|
267
|
-
device=self.llm_device)
|
300
|
+
inputs_embeds, attention_mask, _ = self._get_embeds(
|
301
|
+
question, context, embedding)
|
268
302
|
|
269
303
|
bos_token = self.tokenizer(
|
270
304
|
BOS,
|
@@ -281,3 +315,6 @@ class LLM(torch.nn.Module):
|
|
281
315
|
)
|
282
316
|
|
283
317
|
return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
318
|
+
|
319
|
+
def __repr__(self) -> str:
    # Shows the class name together with the configured model name.
    class_name = self.__class__.__name__
    return f'{class_name}({self.model_name})'
|
@@ -7,18 +7,19 @@ from torch import Tensor
|
|
7
7
|
import torch_geometric.typing
|
8
8
|
from torch_geometric.typing import OptTensor, torch_cluster
|
9
9
|
|
10
|
-
from .asap import ASAPooling
|
11
10
|
from .avg_pool import avg_pool, avg_pool_neighbor_x, avg_pool_x
|
12
|
-
from .edge_pool import EdgePooling
|
13
11
|
from .glob import global_add_pool, global_max_pool, global_mean_pool
|
14
12
|
from .knn import (KNNIndex, L2KNNIndex, MIPSKNNIndex, ApproxL2KNNIndex,
|
15
13
|
ApproxMIPSKNNIndex)
|
16
14
|
from .graclus import graclus
|
17
15
|
from .max_pool import max_pool, max_pool_neighbor_x, max_pool_x
|
18
|
-
from .mem_pool import MemPooling
|
19
|
-
from .pan_pool import PANPooling
|
20
|
-
from .sag_pool import SAGPooling
|
21
16
|
from .topk_pool import TopKPooling
|
17
|
+
from .sag_pool import SAGPooling
|
18
|
+
from .edge_pool import EdgePooling
|
19
|
+
from .cluster_pool import ClusterPooling
|
20
|
+
from .asap import ASAPooling
|
21
|
+
from .pan_pool import PANPooling
|
22
|
+
from .mem_pool import MemPooling
|
22
23
|
from .voxel_grid import voxel_grid
|
23
24
|
from .approx_knn import approx_knn, approx_knn_graph
|
24
25
|
|
@@ -344,6 +345,7 @@ __all__ = [
|
|
344
345
|
'TopKPooling',
|
345
346
|
'SAGPooling',
|
346
347
|
'EdgePooling',
|
348
|
+
'ClusterPooling',
|
347
349
|
'ASAPooling',
|
348
350
|
'PANPooling',
|
349
351
|
'MemPooling',
|
@@ -0,0 +1,145 @@
|
|
1
|
+
from typing import NamedTuple, Optional, Tuple
|
2
|
+
|
3
|
+
import torch
|
4
|
+
import torch.nn.functional as F
|
5
|
+
from torch import Tensor
|
6
|
+
|
7
|
+
from torch_geometric.utils import (
|
8
|
+
dense_to_sparse,
|
9
|
+
one_hot,
|
10
|
+
to_dense_adj,
|
11
|
+
to_scipy_sparse_matrix,
|
12
|
+
)
|
13
|
+
|
14
|
+
|
15
|
+
class UnpoolInfo(NamedTuple):
    r"""Bookkeeping returned by :class:`ClusterPooling` for unpooling.

    :class:`ClusterPooling` fills it with the edge indices it pooled over
    (after dropping self-loops), the cluster id of every input node, and the
    input batch vector.
    """
    # Edge indices of the graph that was pooled.
    edge_index: Tensor
    # Cluster id assigned to each input node.
    cluster: Tensor
    # Batch vector of the input nodes.
    batch: Tensor
|
19
|
+
|
20
|
+
|
21
|
+
class ClusterPooling(torch.nn.Module):
    r"""The cluster pooling operator from the "Edge-Based Graph Component
    Pooling" paper.

    :class:`ClusterPooling` computes a score for each edge.
    Based on the selected edges, graph clusters are calculated and compressed
    to one node using the injective :obj:`"sum"` aggregation function.
    Edges are remapped based on the nodes created by each cluster and the
    original edges.

    Args:
        in_channels (int): Size of each input sample.
        edge_score_method (str, optional): The function to apply
            to compute the edge score from raw edge scores (:obj:`"tanh"`,
            :obj:`"sigmoid"`, :obj:`"log_softmax"`). (default: :obj:`"tanh"`)
        dropout (float, optional): The probability with
            which to drop edge scores during training. (default: :obj:`0.0`)
        threshold (float, optional): The threshold of edge scores. Edges
            whose score is strictly greater than this value are contracted.
            If set to :obj:`None`, will be automatically inferred depending on
            :obj:`edge_score_method` (:obj:`0.5` for :obj:`"sigmoid"`,
            :obj:`0.0` otherwise). (default: :obj:`None`)

    Raises:
        ValueError: If :obj:`edge_score_method` is not supported.
    """
    def __init__(
        self,
        in_channels: int,
        edge_score_method: str = 'tanh',
        dropout: float = 0.0,
        threshold: Optional[float] = None,
    ):
        super().__init__()
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let an unknown method silently fall
        # through to the ``log_softmax`` branch in ``forward``.
        if edge_score_method not in ('tanh', 'sigmoid', 'log_softmax'):
            raise ValueError(
                f"Unsupported 'edge_score_method' (got "
                f"'{edge_score_method}'); expected one of 'tanh', "
                f"'sigmoid' or 'log_softmax'")

        if threshold is None:
            # NOTE(review): log_softmax scores are always <= 0, so with the
            # default 0.0 an edge is (almost) never contracted when using
            # 'log_softmax' -- confirm this default is intended.
            threshold = 0.5 if edge_score_method == 'sigmoid' else 0.0

        self.in_channels = in_channels
        self.edge_score_method = edge_score_method
        self.dropout = dropout
        # Attribute name (sic) kept as-is for backward compatibility.
        self.threshhold = threshold

        # Scores an edge from the concatenated features of its two endpoints.
        self.lin = torch.nn.Linear(2 * in_channels, 1)

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.lin.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The node features.
            edge_index (torch.Tensor): The edge indices.
            batch (torch.Tensor): Batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each node to a specific example.

        Return types:
            * **x** *(torch.Tensor)* - The pooled node features.
            * **edge_index** *(torch.Tensor)* - The coarsened edge indices.
            * **batch** *(torch.Tensor)* - The coarsened batch vector.
            * **unpool_info** *(UnpoolInfo)* - Information that can be consumed
              for unpooling.
        """
        # Self-loops are dropped; they never take part in scoring/merging.
        mask = edge_index[0] != edge_index[1]
        edge_index = edge_index[:, mask]

        edge_attr = torch.cat(
            [x[edge_index[0]], x[edge_index[1]]],
            dim=-1,
        )
        edge_score = self.lin(edge_attr).view(-1)
        edge_score = F.dropout(edge_score, p=self.dropout,
                               training=self.training)

        if self.edge_score_method == 'tanh':
            edge_score = edge_score.tanh()
        elif self.edge_score_method == 'sigmoid':
            edge_score = edge_score.sigmoid()
        elif self.edge_score_method == 'log_softmax':
            # Normalized over all edges of the (mini-)batch at once.
            edge_score = F.log_softmax(edge_score, dim=0)
        else:
            raise ValueError(
                f"Unsupported 'edge_score_method' (got "
                f"'{self.edge_score_method}'); expected one of 'tanh', "
                f"'sigmoid' or 'log_softmax'")

        return self._merge_edges(x, edge_index, batch, edge_score)

    def _merge_edges(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        edge_score: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:
        r"""Contracts high-scoring edges and pools each resulting cluster.

        Note that this builds dense :math:`N \times N` matrices, so memory
        grows quadratically with the number of nodes.
        """
        from scipy.sparse.csgraph import connected_components

        # Edges selected for contraction.
        edge_contract = edge_index[:, edge_score > self.threshhold]

        # Clusters are the weakly connected components of the contracted
        # edges; every node receives a component id.
        adj = to_scipy_sparse_matrix(edge_contract, num_nodes=x.size(0))
        _, cluster_np = connected_components(adj, directed=True,
                                             connection="weak")

        cluster = torch.tensor(cluster_np, dtype=torch.long, device=x.device)
        # One-hot node -> cluster assignment matrix.
        C = one_hot(cluster)
        # Dense adjacency of all (self-loop free) edges.
        A = to_dense_adj(edge_index, max_num_nodes=x.size(0)).squeeze(0)
        # Dense matrix of edge scores.
        S = to_dense_adj(edge_index, edge_attr=edge_score,
                         max_num_nodes=x.size(0)).squeeze(0)

        # Nodes with no contracted edge in either direction get a unit
        # self-score so that their own features are kept.
        A_contract = to_dense_adj(edge_contract,
                                  max_num_nodes=x.size(0)).squeeze(0)
        nodes_single = ((A_contract.sum(dim=-1) +
                         A_contract.sum(dim=-2)) == 0).nonzero()
        S[nodes_single, nodes_single] = 1.0

        # Score-weighted sum of node features per cluster.
        x_out = (S @ C).t() @ x
        # Cluster-level connectivity, with self-loops removed.
        edge_index_out, _ = dense_to_sparse((C.T @ A @ C).fill_diagonal_(0))
        # Each cluster inherits the batch id of its nodes (presumably all
        # nodes of a cluster share one graph, since edges do not cross
        # graphs -- confirm for custom inputs).
        batch_out = batch.new_empty(x_out.size(0)).scatter_(0, cluster, batch)
        unpool_info = UnpoolInfo(edge_index, cluster, batch)

        return x_out, edge_index_out, batch_out, unpool_info

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.in_channels})'
|
File without changes
|