pyg-nightly 2.6.0.dev20240912__py3-none-any.whl → 2.6.0.dev20240913__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.6.0.dev20240912.dist-info → pyg_nightly-2.6.0.dev20240913.dist-info}/METADATA +1 -1
- {pyg_nightly-2.6.0.dev20240912.dist-info → pyg_nightly-2.6.0.dev20240913.dist-info}/RECORD +9 -8
- torch_geometric/__init__.py +1 -1
- torch_geometric/datasets/__init__.py +3 -1
- torch_geometric/datasets/web_qsp_dataset.py +239 -0
- torch_geometric/nn/models/g_retriever.py +18 -4
- torch_geometric/nn/nlp/llm.py +2 -0
- torch_geometric/nn/nlp/sentence_transformer.py +8 -3
- {pyg_nightly-2.6.0.dev20240912.dist-info → pyg_nightly-2.6.0.dev20240913.dist-info}/WHEEL +0 -0
{pyg_nightly-2.6.0.dev20240912.dist-info → pyg_nightly-2.6.0.dev20240913.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: pyg-nightly
|
3
|
-
Version: 2.6.0.dev20240912
|
3
|
+
Version: 2.6.0.dev20240913
|
4
4
|
Summary: Graph Neural Network Library for PyTorch
|
5
5
|
Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
|
6
6
|
Author-email: Matthias Fey <matthias@pyg.org>
|
{pyg_nightly-2.6.0.dev20240912.dist-info → pyg_nightly-2.6.0.dev20240913.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
|
|
1
|
-
torch_geometric/__init__.py,sha256=
|
1
|
+
torch_geometric/__init__.py,sha256=SRoAYR7ELKP-AZQAN4ZGBuNrS8AKNWoF7O_9u_FX6LA,1904
|
2
2
|
torch_geometric/_compile.py,sha256=0HAdz6MGmyrgi4g6P-PorTg8dPIKx3Jo4zVJavrlfX0,1139
|
3
3
|
torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
|
4
4
|
torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
|
@@ -53,7 +53,7 @@ torch_geometric/data/temporal.py,sha256=WOJ6gFrTLikaLhUvotyUF5ql14FkE5Ox3hNkdSp6
|
|
53
53
|
torch_geometric/data/view.py,sha256=XjkVSc-UWZFCT4DlXLShZtO8duhFQkS9gq88zZXANsk,1089
|
54
54
|
torch_geometric/data/lightning/__init__.py,sha256=w3En1tJfy3kSqe1MycpOyZpHFO3fxBCgNCUOznPA3YU,178
|
55
55
|
torch_geometric/data/lightning/datamodule.py,sha256=Bn9iaIfE4NWDDWWMqCvBeZ4bIW1Silx_Ol5CPJCliaQ,29242
|
56
|
-
torch_geometric/datasets/__init__.py,sha256=
|
56
|
+
torch_geometric/datasets/__init__.py,sha256=fey-955PyCQXGBeUTNPWwU5uK3PJOEvaY1_fDt1SxXc,5880
|
57
57
|
torch_geometric/datasets/actor.py,sha256=H8srMdo5qo8eg4LDxEdYcxZi49I_HVDcr8R_pb2W99Q,4461
|
58
58
|
torch_geometric/datasets/airfrans.py,sha256=7Yt0Xs7jx2NotPT4_9GbpLRWRXYSS5g_4zSENoB_9hs,5684
|
59
59
|
torch_geometric/datasets/airports.py,sha256=HSZdi6KM_yavppaUl0uWyZ93BEsrtDf9InjPPu9zaUE,3903
|
@@ -149,6 +149,7 @@ torch_geometric/datasets/tosca.py,sha256=nUSF8NQT1GlkwWQLshjWmr8xORsvRHzzIqhUyDC
|
|
149
149
|
torch_geometric/datasets/tu_dataset.py,sha256=14OSaXBgVwT1dX2h1wZ3xVIwoo0GQBEfR3yWh6Q0VF0,7847
|
150
150
|
torch_geometric/datasets/twitch.py,sha256=qfEerf-Uaojx2ZvegENowdG4E7RoUT_HUO9xtULadvo,3658
|
151
151
|
torch_geometric/datasets/upfd.py,sha256=crqO8uQNz1wC1JOn4prSs8iOGv9LuLK3dZf_KUV9tUE,7010
|
152
|
+
torch_geometric/datasets/web_qsp_dataset.py,sha256=OusHv0DcvDgCjUbBtkhPzwm2pdPlyG98BSzaQPv_GP8,8451
|
152
153
|
torch_geometric/datasets/webkb.py,sha256=beC1kWeW7cIjYwWyaINQSk-3lmVR85Lus7cKZniHp8Y,4879
|
153
154
|
torch_geometric/datasets/wikics.py,sha256=iTzYif1WvbMXnMdhPMfvrkVaAbnM009WiB_f_JWZqhU,3879
|
154
155
|
torch_geometric/datasets/wikidata.py,sha256=9mYShF_HlpTmcdLpiaP_tYJ9eQtUOu5vRPvohN6RXqI,4979
|
@@ -426,7 +427,7 @@ torch_geometric/nn/models/deep_graph_infomax.py,sha256=u6j-5-iHBASDCZ776dyfCI1N8
|
|
426
427
|
torch_geometric/nn/models/deepgcn.py,sha256=tIgT03cj8MghYlxEozpoGvGG_CwpJrGDxv1Z0CVIUts,4339
|
427
428
|
torch_geometric/nn/models/dimenet.py,sha256=Kc5p-rB5q-0e8lY22l-OdQTscTxJh2lTEpeRFMdL4RY,36186
|
428
429
|
torch_geometric/nn/models/dimenet_utils.py,sha256=xP_nbzkSSL25GC3rrZ9KP8x9QZ59S-CZuHzCmQ-K0fI,5062
|
429
|
-
torch_geometric/nn/models/g_retriever.py,sha256=
|
430
|
+
torch_geometric/nn/models/g_retriever.py,sha256=VueRImNJlh1WvRWcsSXliSw8RlxlzWlu2WSFs_VQaJc,7749
|
430
431
|
torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
|
431
432
|
torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
|
432
433
|
torch_geometric/nn/models/graph_unet.py,sha256=WFb7d_DBByMGyXh3AdK2CKNmvMmSKsSUt8l8UnSOovs,5395
|
@@ -449,8 +450,8 @@ torch_geometric/nn/models/signed_gcn.py,sha256=J40CnedFIqtKI1LhW1ITSEFRbA_XiJZL6
|
|
449
450
|
torch_geometric/nn/models/tgn.py,sha256=kEGdfLJybkbMT4UMoAh2nCzfX3_nDjfm1cicuPHEwAM,11878
|
450
451
|
torch_geometric/nn/models/visnet.py,sha256=97OFMCsPDEI5BCSi7RhoRcU2CNRp7zck2tEzrltFZj4,43192
|
451
452
|
torch_geometric/nn/nlp/__init__.py,sha256=JJESTA7w_K8v60XbCd25IqmrKKHLz5OiNexMHYGV2mE,138
|
452
|
-
torch_geometric/nn/nlp/llm.py,sha256=
|
453
|
-
torch_geometric/nn/nlp/sentence_transformer.py,sha256=
|
453
|
+
torch_geometric/nn/nlp/llm.py,sha256=a5YkJA32Ok2PmWFEJ0VJD0HfsauDpxosIwlij6wqwJo,11728
|
454
|
+
torch_geometric/nn/nlp/sentence_transformer.py,sha256=JrTN3W1srdkNX7qYDGB08mY5615i5nfEJSTHAdd5EuA,3260
|
454
455
|
torch_geometric/nn/norm/__init__.py,sha256=u2qIDrkbeuObGVXSAIftAlvSd6ouGTtxznCfD-59UiA,669
|
455
456
|
torch_geometric/nn/norm/batch_norm.py,sha256=sJKrinHGwA-noIgteg1RD2W06rd0zskD-rXuY-36glY,8283
|
456
457
|
torch_geometric/nn/norm/diff_group_norm.py,sha256=b57XvNekrUYGDjNJlGeqvaMGNJmHwopSF0_yyBWlLuA,4722
|
@@ -617,6 +618,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
|
|
617
618
|
torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
|
618
619
|
torch_geometric/visualization/graph.py,sha256=SvbdVx5Zmuy_WSSA4-WWCkqAcCSHVe84mjMfsEWbZCs,4813
|
619
620
|
torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
|
620
|
-
pyg_nightly-2.6.0.
|
621
|
-
pyg_nightly-2.6.0.
|
622
|
-
pyg_nightly-2.6.0.
|
621
|
+
pyg_nightly-2.6.0.dev20240913.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
|
622
|
+
pyg_nightly-2.6.0.dev20240913.dist-info/METADATA,sha256=LOqSGxoPrSOJz1djA4lRO9o1Z0Y0-D8NRau_0xGS1tQ,63068
|
623
|
+
pyg_nightly-2.6.0.dev20240913.dist-info/RECORD,,
|
torch_geometric/__init__.py
CHANGED
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
|
|
30
30
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
31
31
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
32
32
|
|
33
|
-
__version__ = '2.6.0.dev20240912'
|
33
|
+
__version__ = '2.6.0.dev20240913'
|
34
34
|
|
35
35
|
__all__ = [
|
36
36
|
'Index',
|
torch_geometric/datasets/__init__.py
CHANGED
@@ -61,7 +61,6 @@ from .gemsec import GemsecDeezer
|
|
61
61
|
from .twitch import Twitch
|
62
62
|
from .airports import Airports
|
63
63
|
from .lrgb import LRGBDataset
|
64
|
-
from .neurograph import NeuroGraphDataset
|
65
64
|
from .malnet_tiny import MalNetTiny
|
66
65
|
from .omdb import OMDB
|
67
66
|
from .polblogs import PolBlogs
|
@@ -76,6 +75,8 @@ from .jodie import JODIEDataset
|
|
76
75
|
from .wikidata import Wikidata5M
|
77
76
|
from .myket import MyketDataset
|
78
77
|
from .brca_tgca import BrcaTcga
|
78
|
+
from .neurograph import NeuroGraphDataset
|
79
|
+
from .web_qsp_dataset import WebQSPDataset
|
79
80
|
|
80
81
|
from .dbp15k import DBP15K
|
81
82
|
from .aminer import AMiner
|
@@ -188,6 +189,7 @@ homo_datasets = [
|
|
188
189
|
'MyketDataset',
|
189
190
|
'BrcaTcga',
|
190
191
|
'NeuroGraphDataset',
|
192
|
+
'WebQSPDataset',
|
191
193
|
]
|
192
194
|
|
193
195
|
hetero_datasets = [
|
torch_geometric/datasets/web_qsp_dataset.py
ADDED
@@ -0,0 +1,239 @@
|
|
1
|
+
# Code adapted from the G-Retriever paper: https://arxiv.org/abs/2402.07630
|
2
|
+
from typing import Any, Dict, List, Tuple, no_type_check
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
import torch
|
6
|
+
from torch import Tensor
|
7
|
+
from tqdm import tqdm
|
8
|
+
|
9
|
+
from torch_geometric.data import Data, InMemoryDataset
|
10
|
+
from torch_geometric.nn.nlp import SentenceTransformer
|
11
|
+
|
12
|
+
|
13
|
+
@no_type_check
|
14
|
+
def retrieval_via_pcst(
|
15
|
+
data: Data,
|
16
|
+
q_emb: Tensor,
|
17
|
+
textual_nodes: Any,
|
18
|
+
textual_edges: Any,
|
19
|
+
topk: int = 3,
|
20
|
+
topk_e: int = 3,
|
21
|
+
cost_e: float = 0.5,
|
22
|
+
) -> Tuple[Data, str]:
|
23
|
+
c = 0.01
|
24
|
+
if len(textual_nodes) == 0 or len(textual_edges) == 0:
|
25
|
+
desc = textual_nodes.to_csv(index=False) + "\n" + textual_edges.to_csv(
|
26
|
+
index=False,
|
27
|
+
columns=["src", "edge_attr", "dst"],
|
28
|
+
)
|
29
|
+
return data, desc
|
30
|
+
|
31
|
+
from pcst_fast import pcst_fast
|
32
|
+
|
33
|
+
root = -1
|
34
|
+
num_clusters = 1
|
35
|
+
pruning = 'gw'
|
36
|
+
verbosity_level = 0
|
37
|
+
if topk > 0:
|
38
|
+
n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.x)
|
39
|
+
topk = min(topk, data.num_nodes)
|
40
|
+
_, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
|
41
|
+
|
42
|
+
n_prizes = torch.zeros_like(n_prizes)
|
43
|
+
n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
|
44
|
+
else:
|
45
|
+
n_prizes = torch.zeros(data.num_nodes)
|
46
|
+
|
47
|
+
if topk_e > 0:
|
48
|
+
e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.edge_attr)
|
49
|
+
topk_e = min(topk_e, e_prizes.unique().size(0))
|
50
|
+
|
51
|
+
topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e, largest=True)
|
52
|
+
e_prizes[e_prizes < topk_e_values[-1]] = 0.0
|
53
|
+
last_topk_e_value = topk_e
|
54
|
+
for k in range(topk_e):
|
55
|
+
indices = e_prizes == topk_e_values[k]
|
56
|
+
value = min((topk_e - k) / sum(indices), last_topk_e_value - c)
|
57
|
+
e_prizes[indices] = value
|
58
|
+
last_topk_e_value = value * (1 - c)
|
59
|
+
# reduce the cost of the edges such that at least one edge is selected
|
60
|
+
cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
|
61
|
+
else:
|
62
|
+
e_prizes = torch.zeros(data.num_edges)
|
63
|
+
|
64
|
+
costs = []
|
65
|
+
edges = []
|
66
|
+
virtual_n_prizes = []
|
67
|
+
virtual_edges = []
|
68
|
+
virtual_costs = []
|
69
|
+
mapping_n = {}
|
70
|
+
mapping_e = {}
|
71
|
+
for i, (src, dst) in enumerate(data.edge_index.t().numpy()):
|
72
|
+
prize_e = e_prizes[i]
|
73
|
+
if prize_e <= cost_e:
|
74
|
+
mapping_e[len(edges)] = i
|
75
|
+
edges.append((src, dst))
|
76
|
+
costs.append(cost_e - prize_e)
|
77
|
+
else:
|
78
|
+
virtual_node_id = data.num_nodes + len(virtual_n_prizes)
|
79
|
+
mapping_n[virtual_node_id] = i
|
80
|
+
virtual_edges.append((src, virtual_node_id))
|
81
|
+
virtual_edges.append((virtual_node_id, dst))
|
82
|
+
virtual_costs.append(0)
|
83
|
+
virtual_costs.append(0)
|
84
|
+
virtual_n_prizes.append(prize_e - cost_e)
|
85
|
+
|
86
|
+
prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])
|
87
|
+
num_edges = len(edges)
|
88
|
+
if len(virtual_costs) > 0:
|
89
|
+
costs = np.array(costs + virtual_costs)
|
90
|
+
edges = np.array(edges + virtual_edges)
|
91
|
+
|
92
|
+
vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters,
|
93
|
+
pruning, verbosity_level)
|
94
|
+
|
95
|
+
selected_nodes = vertices[vertices < data.num_nodes]
|
96
|
+
selected_edges = [mapping_e[e] for e in edges if e < num_edges]
|
97
|
+
virtual_vertices = vertices[vertices >= data.num_nodes]
|
98
|
+
if len(virtual_vertices) > 0:
|
99
|
+
virtual_vertices = vertices[vertices >= data.num_nodes]
|
100
|
+
virtual_edges = [mapping_n[i] for i in virtual_vertices]
|
101
|
+
selected_edges = np.array(selected_edges + virtual_edges)
|
102
|
+
|
103
|
+
edge_index = data.edge_index[:, selected_edges]
|
104
|
+
selected_nodes = np.unique(
|
105
|
+
np.concatenate(
|
106
|
+
[selected_nodes, edge_index[0].numpy(), edge_index[1].numpy()]))
|
107
|
+
|
108
|
+
n = textual_nodes.iloc[selected_nodes]
|
109
|
+
e = textual_edges.iloc[selected_edges]
|
110
|
+
desc = n.to_csv(index=False) + '\n' + e.to_csv(
|
111
|
+
index=False, columns=['src', 'edge_attr', 'dst'])
|
112
|
+
|
113
|
+
mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
|
114
|
+
src = [mapping[i] for i in edge_index[0].tolist()]
|
115
|
+
dst = [mapping[i] for i in edge_index[1].tolist()]
|
116
|
+
|
117
|
+
data = Data(
|
118
|
+
x=data.x[selected_nodes],
|
119
|
+
edge_index=torch.tensor([src, dst]),
|
120
|
+
edge_attr=data.edge_attr[selected_edges],
|
121
|
+
)
|
122
|
+
|
123
|
+
return data, desc
|
124
|
+
|
125
|
+
|
126
|
+
class WebQSPDataset(InMemoryDataset):
|
127
|
+
r"""The WebQuestionsSP dataset of the `"The Value of Semantic Parse
|
128
|
+
Labeling for Knowledge Base Question Answering"
|
129
|
+
<https://aclanthology.org/P16-2033/>`_ paper.
|
130
|
+
|
131
|
+
Args:
|
132
|
+
root (str): Root directory where the dataset should be saved.
|
133
|
+
split (str, optional): If :obj:`"train"`, loads the training dataset.
|
134
|
+
If :obj:`"val"`, loads the validation dataset.
|
135
|
+
If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
|
136
|
+
force_reload (bool, optional): Whether to re-process the dataset.
|
137
|
+
(default: :obj:`False`)
|
138
|
+
"""
|
139
|
+
def __init__(
|
140
|
+
self,
|
141
|
+
root: str,
|
142
|
+
split: str = "train",
|
143
|
+
force_reload: bool = False,
|
144
|
+
) -> None:
|
145
|
+
super().__init__(root, force_reload=force_reload)
|
146
|
+
|
147
|
+
if split not in {'train', 'val', 'test'}:
|
148
|
+
raise ValueError(f"Invalid 'split' argument (got {split})")
|
149
|
+
|
150
|
+
path = self.processed_paths[['train', 'val', 'test'].index(split)]
|
151
|
+
self.load(path)
|
152
|
+
|
153
|
+
@property
|
154
|
+
def processed_file_names(self) -> List[str]:
|
155
|
+
return ['train_data.pt', 'val_data.pt', 'test_data.pt']
|
156
|
+
|
157
|
+
def process(self) -> None:
|
158
|
+
import datasets
|
159
|
+
import pandas as pd
|
160
|
+
|
161
|
+
datasets = datasets.load_dataset('rmanluo/RoG-webqsp')
|
162
|
+
|
163
|
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
164
|
+
model_name = 'sentence-transformers/all-roberta-large-v1'
|
165
|
+
model = SentenceTransformer(model_name).to(device)
|
166
|
+
model.eval()
|
167
|
+
|
168
|
+
for dataset, path in zip(
|
169
|
+
[datasets['train'], datasets['validation'], datasets['test']],
|
170
|
+
self.processed_paths,
|
171
|
+
):
|
172
|
+
questions = [example["question"] for example in dataset]
|
173
|
+
question_embs = model.encode(
|
174
|
+
questions,
|
175
|
+
batch_size=256,
|
176
|
+
output_device='cpu',
|
177
|
+
)
|
178
|
+
|
179
|
+
data_list = []
|
180
|
+
for i, example in enumerate(tqdm(dataset)):
|
181
|
+
raw_nodes: Dict[str, int] = {}
|
182
|
+
raw_edges = []
|
183
|
+
for tri in example["graph"]:
|
184
|
+
h, r, t = tri
|
185
|
+
h = h.lower()
|
186
|
+
t = t.lower()
|
187
|
+
if h not in raw_nodes:
|
188
|
+
raw_nodes[h] = len(raw_nodes)
|
189
|
+
if t not in raw_nodes:
|
190
|
+
raw_nodes[t] = len(raw_nodes)
|
191
|
+
raw_edges.append({
|
192
|
+
"src": raw_nodes[h],
|
193
|
+
"edge_attr": r,
|
194
|
+
"dst": raw_nodes[t]
|
195
|
+
})
|
196
|
+
nodes = pd.DataFrame([{
|
197
|
+
"node_id": v,
|
198
|
+
"node_attr": k,
|
199
|
+
} for k, v in raw_nodes.items()])
|
200
|
+
edges = pd.DataFrame(raw_edges)
|
201
|
+
|
202
|
+
nodes.node_attr = nodes.node_attr.fillna("")
|
203
|
+
x = model.encode(
|
204
|
+
nodes.node_attr.tolist(),
|
205
|
+
batch_size=256,
|
206
|
+
output_device='cpu',
|
207
|
+
)
|
208
|
+
edge_attr = model.encode(
|
209
|
+
edges.edge_attr.tolist(),
|
210
|
+
batch_size=256,
|
211
|
+
output_device='cpu',
|
212
|
+
)
|
213
|
+
edge_index = torch.tensor([
|
214
|
+
edges.src.tolist(),
|
215
|
+
edges.dst.tolist(),
|
216
|
+
])
|
217
|
+
|
218
|
+
question = f"Question: {example['question']}\nAnswer: "
|
219
|
+
label = ('|').join(example['answer']).lower()
|
220
|
+
data = Data(
|
221
|
+
x=x,
|
222
|
+
edge_index=edge_index,
|
223
|
+
edge_attr=edge_attr,
|
224
|
+
)
|
225
|
+
data, desc = retrieval_via_pcst(
|
226
|
+
data,
|
227
|
+
question_embs[i],
|
228
|
+
nodes,
|
229
|
+
edges,
|
230
|
+
topk=3,
|
231
|
+
topk_e=5,
|
232
|
+
cost_e=0.5,
|
233
|
+
)
|
234
|
+
data.question = question
|
235
|
+
data.label = label
|
236
|
+
data.desc = desc
|
237
|
+
data_list.append(data)
|
238
|
+
|
239
|
+
self.save(data_list, path)
|
torch_geometric/nn/models/g_retriever.py
CHANGED
@@ -3,7 +3,6 @@ from typing import List, Optional
|
|
3
3
|
import torch
|
4
4
|
from torch import Tensor
|
5
5
|
|
6
|
-
from torch_geometric.nn.models import GAT
|
7
6
|
from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS
|
8
7
|
from torch_geometric.utils import scatter
|
9
8
|
|
@@ -43,7 +42,6 @@ class GRetriever(torch.nn.Module):
|
|
43
42
|
llm: LLM,
|
44
43
|
gnn: torch.nn.Module,
|
45
44
|
use_lora: bool = False,
|
46
|
-
gnn_to_use=GAT,
|
47
45
|
mlp_out_channels: int = 4096,
|
48
46
|
) -> None:
|
49
47
|
super().__init__()
|
@@ -126,7 +124,15 @@ class GRetriever(torch.nn.Module):
|
|
126
124
|
"""
|
127
125
|
x = self.encode(x, edge_index, batch, edge_attr)
|
128
126
|
x = self.projector(x)
|
129
|
-
xs = x.split(
|
127
|
+
xs = x.split(1, dim=0)
|
128
|
+
|
129
|
+
# Handle questions without node features:
|
130
|
+
batch_unique = batch.unique()
|
131
|
+
batch_size = len(question)
|
132
|
+
if len(batch_unique) < batch_size:
|
133
|
+
xs = [
|
134
|
+
xs[i] if i in batch_unique else None for i in range(batch_size)
|
135
|
+
]
|
130
136
|
|
131
137
|
(
|
132
138
|
inputs_embeds,
|
@@ -174,7 +180,15 @@ class GRetriever(torch.nn.Module):
|
|
174
180
|
"""
|
175
181
|
x = self.encode(x, edge_index, batch, edge_attr)
|
176
182
|
x = self.projector(x)
|
177
|
-
xs = x.split(
|
183
|
+
xs = x.split(1, dim=0)
|
184
|
+
|
185
|
+
# Handle questions without node features:
|
186
|
+
batch_unique = batch.unique()
|
187
|
+
batch_size = len(question)
|
188
|
+
if len(batch_unique) < batch_size:
|
189
|
+
xs = [
|
190
|
+
xs[i] if i in batch_unique else None for i in range(batch_size)
|
191
|
+
]
|
178
192
|
|
179
193
|
inputs_embeds, attention_mask, _ = self.llm._get_embeds(
|
180
194
|
question, additional_text_context, xs)
|
torch_geometric/nn/nlp/llm.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
import warnings
|
1
2
|
from contextlib import nullcontext
|
2
3
|
from typing import Any, Dict, List, Optional
|
3
4
|
|
@@ -85,6 +86,7 @@ class LLM(torch.nn.Module):
|
|
85
86
|
self.word_embedding = self.llm.model.get_input_embeddings()
|
86
87
|
|
87
88
|
if 'max_memory' not in kwargs: # Pure CPU:
|
89
|
+
warnings.warn("LLM is being used on CPU, which may be slow")
|
88
90
|
self.device = torch.device('cpu')
|
89
91
|
self.autocast_context = nullcontext()
|
90
92
|
else:
|
torch_geometric/nn/nlp/sentence_transformer.py
CHANGED
@@ -54,8 +54,11 @@ class SentenceTransformer(torch.nn.Module):
|
|
54
54
|
self,
|
55
55
|
text: List[str],
|
56
56
|
batch_size: Optional[int] = None,
|
57
|
-
output_device: Optional[torch.device] = None,
|
57
|
+
output_device: Optional[Union[torch.device, str]] = None,
|
58
58
|
) -> Tensor:
|
59
|
+
is_empty = len(text) == 0
|
60
|
+
text = ['dummy'] if is_empty else text
|
61
|
+
|
59
62
|
batch_size = len(text) if batch_size is None else batch_size
|
60
63
|
|
61
64
|
embs: List[Tensor] = []
|
@@ -70,11 +73,13 @@ class SentenceTransformer(torch.nn.Module):
|
|
70
73
|
emb = self(
|
71
74
|
input_ids=token.input_ids.to(self.device),
|
72
75
|
attention_mask=token.attention_mask.to(self.device),
|
73
|
-
).to(output_device
|
76
|
+
).to(output_device)
|
74
77
|
|
75
78
|
embs.append(emb)
|
76
79
|
|
77
|
-
|
80
|
+
out = torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]
|
81
|
+
out = out[:0] if is_empty else out
|
82
|
+
return out
|
78
83
|
|
79
84
|
def __repr__(self) -> str:
|
80
85
|
return f'{self.__class__.__name__}(model_name={self.model_name})'
|
{pyg_nightly-2.6.0.dev20240912.dist-info → pyg_nightly-2.6.0.dev20240913.dist-info}/WHEEL
RENAMED
File without changes
|