pyg-nightly 2.6.0.dev20240912__py3-none-any.whl → 2.7.0.dev20240914__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.6.0.dev20240912.dist-info → pyg_nightly-2.7.0.dev20240914.dist-info}/METADATA +1 -1
- {pyg_nightly-2.6.0.dev20240912.dist-info → pyg_nightly-2.7.0.dev20240914.dist-info}/RECORD +11 -10
- torch_geometric/__init__.py +1 -1
- torch_geometric/datasets/__init__.py +3 -1
- torch_geometric/datasets/web_qsp_dataset.py +239 -0
- torch_geometric/edge_index.py +1 -1
- torch_geometric/index.py +2 -2
- torch_geometric/nn/models/g_retriever.py +18 -4
- torch_geometric/nn/nlp/llm.py +2 -0
- torch_geometric/nn/nlp/sentence_transformer.py +8 -3
- {pyg_nightly-2.6.0.dev20240912.dist-info → pyg_nightly-2.7.0.dev20240914.dist-info}/WHEEL +0 -0
{pyg_nightly-2.6.0.dev20240912.dist-info → pyg_nightly-2.7.0.dev20240914.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: pyg-nightly
-Version: 2.6.0.dev20240912
+Version: 2.7.0.dev20240914
 Summary: Graph Neural Network Library for PyTorch
 Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
 Author-email: Matthias Fey <matthias@pyg.org>
{pyg_nightly-2.6.0.dev20240912.dist-info → pyg_nightly-2.7.0.dev20240914.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
-torch_geometric/__init__.py,sha256=
+torch_geometric/__init__.py,sha256=n23o0W0D6f8VcOq-d2ljeW7QxvqdEiX9hzH_YvgmLK0,1904
 torch_geometric/_compile.py,sha256=0HAdz6MGmyrgi4g6P-PorTg8dPIKx3Jo4zVJavrlfX0,1139
 torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
 torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -7,10 +7,10 @@ torch_geometric/config_store.py,sha256=zdMzlgBpUmBkPovpYQh5fMNwTZLDq2OneqX47QEx7
 torch_geometric/debug.py,sha256=cLyH9OaL2v7POyW-80b19w-ctA7a_5EZsS4aUF1wc2U,1295
 torch_geometric/deprecation.py,sha256=dWRymDIUkUVI2MeEmBG5WF4R6jObZeseSBV9G6FNfjc,858
 torch_geometric/device.py,sha256=tU5-_lBNVbVHl_kUmWPwiG5mQ1pyapwMF4JkmtNN3MM,1224
-torch_geometric/edge_index.py,sha256=
+torch_geometric/edge_index.py,sha256=r4_24Rhm3YCK0BF-kzLvL7PlY_1tWcXrBDIr7JPDjkw,70048
 torch_geometric/experimental.py,sha256=JbtNNEXjFGI8hZ9raM6-qrZURP6Z5nlDK8QicZUIbz0,4756
 torch_geometric/home.py,sha256=EV54B4Dmiv61GDbkCwtCfWGWJ4eFGwZ8s3KOgGjwYgY,790
-torch_geometric/index.py,sha256=
+torch_geometric/index.py,sha256=9ChzWFCwj2slNcVBOgfV-wQn-KscJe_y7502w-Vf76w,24045
 torch_geometric/inspector.py,sha256=9M61T9ruSid5-r2aelRAeX9g_7AZ1VMnYAB2KozM71E,19267
 torch_geometric/isinstance.py,sha256=truZjdU9PxSvjJ6k0d_CLJ2iOpen2o8U-54pbUbNRyE,935
 torch_geometric/lazy_loader.py,sha256=SM0UcXtIdiFge75MKBAWXedoiSOdFDOV0rm1PfoF9cE,908
@@ -53,7 +53,7 @@ torch_geometric/data/temporal.py,sha256=WOJ6gFrTLikaLhUvotyUF5ql14FkE5Ox3hNkdSp6
 torch_geometric/data/view.py,sha256=XjkVSc-UWZFCT4DlXLShZtO8duhFQkS9gq88zZXANsk,1089
 torch_geometric/data/lightning/__init__.py,sha256=w3En1tJfy3kSqe1MycpOyZpHFO3fxBCgNCUOznPA3YU,178
 torch_geometric/data/lightning/datamodule.py,sha256=Bn9iaIfE4NWDDWWMqCvBeZ4bIW1Silx_Ol5CPJCliaQ,29242
-torch_geometric/datasets/__init__.py,sha256=
+torch_geometric/datasets/__init__.py,sha256=fey-955PyCQXGBeUTNPWwU5uK3PJOEvaY1_fDt1SxXc,5880
 torch_geometric/datasets/actor.py,sha256=H8srMdo5qo8eg4LDxEdYcxZi49I_HVDcr8R_pb2W99Q,4461
 torch_geometric/datasets/airfrans.py,sha256=7Yt0Xs7jx2NotPT4_9GbpLRWRXYSS5g_4zSENoB_9hs,5684
 torch_geometric/datasets/airports.py,sha256=HSZdi6KM_yavppaUl0uWyZ93BEsrtDf9InjPPu9zaUE,3903
@@ -149,6 +149,7 @@ torch_geometric/datasets/tosca.py,sha256=nUSF8NQT1GlkwWQLshjWmr8xORsvRHzzIqhUyDC
 torch_geometric/datasets/tu_dataset.py,sha256=14OSaXBgVwT1dX2h1wZ3xVIwoo0GQBEfR3yWh6Q0VF0,7847
 torch_geometric/datasets/twitch.py,sha256=qfEerf-Uaojx2ZvegENowdG4E7RoUT_HUO9xtULadvo,3658
 torch_geometric/datasets/upfd.py,sha256=crqO8uQNz1wC1JOn4prSs8iOGv9LuLK3dZf_KUV9tUE,7010
+torch_geometric/datasets/web_qsp_dataset.py,sha256=OusHv0DcvDgCjUbBtkhPzwm2pdPlyG98BSzaQPv_GP8,8451
 torch_geometric/datasets/webkb.py,sha256=beC1kWeW7cIjYwWyaINQSk-3lmVR85Lus7cKZniHp8Y,4879
 torch_geometric/datasets/wikics.py,sha256=iTzYif1WvbMXnMdhPMfvrkVaAbnM009WiB_f_JWZqhU,3879
 torch_geometric/datasets/wikidata.py,sha256=9mYShF_HlpTmcdLpiaP_tYJ9eQtUOu5vRPvohN6RXqI,4979
@@ -426,7 +427,7 @@ torch_geometric/nn/models/deep_graph_infomax.py,sha256=u6j-5-iHBASDCZ776dyfCI1N8
 torch_geometric/nn/models/deepgcn.py,sha256=tIgT03cj8MghYlxEozpoGvGG_CwpJrGDxv1Z0CVIUts,4339
 torch_geometric/nn/models/dimenet.py,sha256=Kc5p-rB5q-0e8lY22l-OdQTscTxJh2lTEpeRFMdL4RY,36186
 torch_geometric/nn/models/dimenet_utils.py,sha256=xP_nbzkSSL25GC3rrZ9KP8x9QZ59S-CZuHzCmQ-K0fI,5062
-torch_geometric/nn/models/g_retriever.py,sha256=
+torch_geometric/nn/models/g_retriever.py,sha256=VueRImNJlh1WvRWcsSXliSw8RlxlzWlu2WSFs_VQaJc,7749
 torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
 torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
 torch_geometric/nn/models/graph_unet.py,sha256=WFb7d_DBByMGyXh3AdK2CKNmvMmSKsSUt8l8UnSOovs,5395
@@ -449,8 +450,8 @@ torch_geometric/nn/models/signed_gcn.py,sha256=J40CnedFIqtKI1LhW1ITSEFRbA_XiJZL6
 torch_geometric/nn/models/tgn.py,sha256=kEGdfLJybkbMT4UMoAh2nCzfX3_nDjfm1cicuPHEwAM,11878
 torch_geometric/nn/models/visnet.py,sha256=97OFMCsPDEI5BCSi7RhoRcU2CNRp7zck2tEzrltFZj4,43192
 torch_geometric/nn/nlp/__init__.py,sha256=JJESTA7w_K8v60XbCd25IqmrKKHLz5OiNexMHYGV2mE,138
-torch_geometric/nn/nlp/llm.py,sha256=
-torch_geometric/nn/nlp/sentence_transformer.py,sha256=
+torch_geometric/nn/nlp/llm.py,sha256=a5YkJA32Ok2PmWFEJ0VJD0HfsauDpxosIwlij6wqwJo,11728
+torch_geometric/nn/nlp/sentence_transformer.py,sha256=JrTN3W1srdkNX7qYDGB08mY5615i5nfEJSTHAdd5EuA,3260
 torch_geometric/nn/norm/__init__.py,sha256=u2qIDrkbeuObGVXSAIftAlvSd6ouGTtxznCfD-59UiA,669
 torch_geometric/nn/norm/batch_norm.py,sha256=sJKrinHGwA-noIgteg1RD2W06rd0zskD-rXuY-36glY,8283
 torch_geometric/nn/norm/diff_group_norm.py,sha256=b57XvNekrUYGDjNJlGeqvaMGNJmHwopSF0_yyBWlLuA,4722
@@ -617,6 +618,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
 torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
 torch_geometric/visualization/graph.py,sha256=SvbdVx5Zmuy_WSSA4-WWCkqAcCSHVe84mjMfsEWbZCs,4813
 torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
-pyg_nightly-2.6.0.dev20240912.dist-info/WHEEL,sha256=
-pyg_nightly-2.6.0.dev20240912.dist-info/METADATA,sha256=
-pyg_nightly-2.6.0.dev20240912.dist-info/RECORD,,
+pyg_nightly-2.7.0.dev20240914.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
+pyg_nightly-2.7.0.dev20240914.dist-info/METADATA,sha256=Pz5f49zwvSFV-FkxufA5cJDayObU27GV4gtdnv1KN0g,63068
+pyg_nightly-2.7.0.dev20240914.dist-info/RECORD,,
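Each RECORD entry above follows the wheel convention (PEP 376/PEP 427): `path,sha256=<digest>,<size>`, where the digest is the urlsafe-base64 encoding of the file's SHA-256 hash with trailing `=` padding stripped. A minimal sketch for recomputing an entry locally (the target path is illustrative):

import base64
import hashlib
from pathlib import Path

def record_entry(path: str) -> str:
    # RECORD format: <path>,sha256=<unpadded urlsafe-b64 digest>,<size>
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(
        hashlib.sha256(data).digest()).rstrip(b'=').decode()
    return f'{path},sha256={digest},{len(data)}'

# e.g. against a file from the unpacked wheel (hypothetical location):
# print(record_entry('torch_geometric/index.py'))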
torch_geometric/__init__.py
CHANGED
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
 
-__version__ = '2.6.0.dev20240912'
+__version__ = '2.7.0.dev20240914'
 
 __all__ = [
     'Index',
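A quick way to confirm which nightly is active after upgrading (a sketch, assuming the new wheel is installed):

import torch_geometric

print(torch_geometric.__version__)  # expected: '2.7.0.dev20240914'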
torch_geometric/datasets/__init__.py
CHANGED
@@ -61,7 +61,6 @@ from .gemsec import GemsecDeezer
 from .twitch import Twitch
 from .airports import Airports
 from .lrgb import LRGBDataset
-from .neurograph import NeuroGraphDataset
 from .malnet_tiny import MalNetTiny
 from .omdb import OMDB
 from .polblogs import PolBlogs
@@ -76,6 +75,8 @@ from .jodie import JODIEDataset
 from .wikidata import Wikidata5M
 from .myket import MyketDataset
 from .brca_tgca import BrcaTcga
+from .neurograph import NeuroGraphDataset
+from .web_qsp_dataset import WebQSPDataset
 
 from .dbp15k import DBP15K
 from .aminer import AMiner
@@ -188,6 +189,7 @@ homo_datasets = [
     'MyketDataset',
     'BrcaTcga',
     'NeuroGraphDataset',
+    'WebQSPDataset',
 ]
 
 hetero_datasets = [
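With the import and `homo_datasets` registration above, the new dataset is exposed at the usual top-level location:

from torch_geometric.datasets import WebQSPDataset    # newly exported in this build
from torch_geometric.datasets import NeuroGraphDataset  # still exported; only its import statement moved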
torch_geometric/datasets/web_qsp_dataset.py
ADDED
@@ -0,0 +1,239 @@
+# Code adapted from the G-Retriever paper: https://arxiv.org/abs/2402.07630
+from typing import Any, Dict, List, Tuple, no_type_check
+
+import numpy as np
+import torch
+from torch import Tensor
+from tqdm import tqdm
+
+from torch_geometric.data import Data, InMemoryDataset
+from torch_geometric.nn.nlp import SentenceTransformer
+
+
+@no_type_check
+def retrieval_via_pcst(
+    data: Data,
+    q_emb: Tensor,
+    textual_nodes: Any,
+    textual_edges: Any,
+    topk: int = 3,
+    topk_e: int = 3,
+    cost_e: float = 0.5,
+) -> Tuple[Data, str]:
+    c = 0.01
+    if len(textual_nodes) == 0 or len(textual_edges) == 0:
+        desc = textual_nodes.to_csv(index=False) + "\n" + textual_edges.to_csv(
+            index=False,
+            columns=["src", "edge_attr", "dst"],
+        )
+        return data, desc
+
+    from pcst_fast import pcst_fast
+
+    root = -1
+    num_clusters = 1
+    pruning = 'gw'
+    verbosity_level = 0
+    if topk > 0:
+        n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.x)
+        topk = min(topk, data.num_nodes)
+        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
+
+        n_prizes = torch.zeros_like(n_prizes)
+        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
+    else:
+        n_prizes = torch.zeros(data.num_nodes)
+
+    if topk_e > 0:
+        e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.edge_attr)
+        topk_e = min(topk_e, e_prizes.unique().size(0))
+
+        topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e, largest=True)
+        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
+        last_topk_e_value = topk_e
+        for k in range(topk_e):
+            indices = e_prizes == topk_e_values[k]
+            value = min((topk_e - k) / sum(indices), last_topk_e_value - c)
+            e_prizes[indices] = value
+            last_topk_e_value = value * (1 - c)
+        # reduce the cost of the edges such that at least one edge is selected
+        cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
+    else:
+        e_prizes = torch.zeros(data.num_edges)
+
+    costs = []
+    edges = []
+    virtual_n_prizes = []
+    virtual_edges = []
+    virtual_costs = []
+    mapping_n = {}
+    mapping_e = {}
+    for i, (src, dst) in enumerate(data.edge_index.t().numpy()):
+        prize_e = e_prizes[i]
+        if prize_e <= cost_e:
+            mapping_e[len(edges)] = i
+            edges.append((src, dst))
+            costs.append(cost_e - prize_e)
+        else:
+            virtual_node_id = data.num_nodes + len(virtual_n_prizes)
+            mapping_n[virtual_node_id] = i
+            virtual_edges.append((src, virtual_node_id))
+            virtual_edges.append((virtual_node_id, dst))
+            virtual_costs.append(0)
+            virtual_costs.append(0)
+            virtual_n_prizes.append(prize_e - cost_e)
+
+    prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])
+    num_edges = len(edges)
+    if len(virtual_costs) > 0:
+        costs = np.array(costs + virtual_costs)
+        edges = np.array(edges + virtual_edges)
+
+    vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters,
+                                pruning, verbosity_level)
+
+    selected_nodes = vertices[vertices < data.num_nodes]
+    selected_edges = [mapping_e[e] for e in edges if e < num_edges]
+    virtual_vertices = vertices[vertices >= data.num_nodes]
+    if len(virtual_vertices) > 0:
+        virtual_vertices = vertices[vertices >= data.num_nodes]
+        virtual_edges = [mapping_n[i] for i in virtual_vertices]
+        selected_edges = np.array(selected_edges + virtual_edges)
+
+    edge_index = data.edge_index[:, selected_edges]
+    selected_nodes = np.unique(
+        np.concatenate(
+            [selected_nodes, edge_index[0].numpy(), edge_index[1].numpy()]))
+
+    n = textual_nodes.iloc[selected_nodes]
+    e = textual_edges.iloc[selected_edges]
+    desc = n.to_csv(index=False) + '\n' + e.to_csv(
+        index=False, columns=['src', 'edge_attr', 'dst'])
+
+    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
+    src = [mapping[i] for i in edge_index[0].tolist()]
+    dst = [mapping[i] for i in edge_index[1].tolist()]
+
+    data = Data(
+        x=data.x[selected_nodes],
+        edge_index=torch.tensor([src, dst]),
+        edge_attr=data.edge_attr[selected_edges],
+    )
+
+    return data, desc
+
+
+class WebQSPDataset(InMemoryDataset):
+    r"""The WebQuestionsSP dataset of the `"The Value of Semantic Parse
+    Labeling for Knowledge Base Question Answering"
+    <https://aclanthology.org/P16-2033/>`_ paper.
+
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        split (str, optional): If :obj:`"train"`, loads the training dataset.
+            If :obj:`"val"`, loads the validation dataset.
+            If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
+        force_reload (bool, optional): Whether to re-process the dataset.
+            (default: :obj:`False`)
+    """
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        force_reload: bool = False,
+    ) -> None:
+        super().__init__(root, force_reload=force_reload)
+
+        if split not in {'train', 'val', 'test'}:
+            raise ValueError(f"Invalid 'split' argument (got {split})")
+
+        path = self.processed_paths[['train', 'val', 'test'].index(split)]
+        self.load(path)
+
+    @property
+    def processed_file_names(self) -> List[str]:
+        return ['train_data.pt', 'val_data.pt', 'test_data.pt']
+
+    def process(self) -> None:
+        import datasets
+        import pandas as pd
+
+        datasets = datasets.load_dataset('rmanluo/RoG-webqsp')
+
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        model_name = 'sentence-transformers/all-roberta-large-v1'
+        model = SentenceTransformer(model_name).to(device)
+        model.eval()
+
+        for dataset, path in zip(
+            [datasets['train'], datasets['validation'], datasets['test']],
+            self.processed_paths,
+        ):
+            questions = [example["question"] for example in dataset]
+            question_embs = model.encode(
+                questions,
+                batch_size=256,
+                output_device='cpu',
+            )
+
+            data_list = []
+            for i, example in enumerate(tqdm(dataset)):
+                raw_nodes: Dict[str, int] = {}
+                raw_edges = []
+                for tri in example["graph"]:
+                    h, r, t = tri
+                    h = h.lower()
+                    t = t.lower()
+                    if h not in raw_nodes:
+                        raw_nodes[h] = len(raw_nodes)
+                    if t not in raw_nodes:
+                        raw_nodes[t] = len(raw_nodes)
+                    raw_edges.append({
+                        "src": raw_nodes[h],
+                        "edge_attr": r,
+                        "dst": raw_nodes[t]
+                    })
+                nodes = pd.DataFrame([{
+                    "node_id": v,
+                    "node_attr": k,
+                } for k, v in raw_nodes.items()])
+                edges = pd.DataFrame(raw_edges)
+
+                nodes.node_attr = nodes.node_attr.fillna("")
+                x = model.encode(
+                    nodes.node_attr.tolist(),
+                    batch_size=256,
+                    output_device='cpu',
+                )
+                edge_attr = model.encode(
+                    edges.edge_attr.tolist(),
+                    batch_size=256,
+                    output_device='cpu',
+                )
+                edge_index = torch.tensor([
+                    edges.src.tolist(),
+                    edges.dst.tolist(),
+                ])
+
+                question = f"Question: {example['question']}\nAnswer: "
+                label = ('|').join(example['answer']).lower()
+                data = Data(
+                    x=x,
+                    edge_index=edge_index,
+                    edge_attr=edge_attr,
+                )
+                data, desc = retrieval_via_pcst(
+                    data,
+                    question_embs[i],
+                    nodes,
+                    edges,
+                    topk=3,
+                    topk_e=5,
+                    cost_e=0.5,
+                )
+                data.question = question
+                data.label = label
+                data.desc = desc
+                data_list.append(data)
+
+            self.save(data_list, path)
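A minimal usage sketch for the new dataset. Processing downloads the 'rmanluo/RoG-webqsp' splits from the Hugging Face Hub, embeds nodes and edges, and runs PCST-based subgraph retrieval per question, so the optional `datasets`, `pandas`, and `pcst_fast` packages must be installed; the root directory below is illustrative:

from torch_geometric.datasets import WebQSPDataset

# First instantiation triggers download, embedding, and PCST retrieval:
dataset = WebQSPDataset(root='data/WebQSP', split='train')

data = dataset[0]
print(data.question)  # "Question: ...\nAnswer: "
print(data.label)     # '|'-joined, lowercased answer entities
print(data.desc)      # CSV dump of the retrieved subgraph's nodes/edges
print(data.x.shape)   # node embeddings from all-roberta-large-v1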
torch_geometric/edge_index.py
CHANGED
@@ -173,7 +173,7 @@ class EdgeIndex(Tensor):
     :meth:`EdgeIndex.fill_cache_`, and are maintained and adjusted over its
     lifespan (*e.g.*, when calling :meth:`EdgeIndex.flip`).
 
-    This representation ensures
+    This representation ensures optimal computation in GNN message passing
     schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
     workflows.
 
torch_geometric/index.py
CHANGED
@@ -106,7 +106,7 @@ class Index(Tensor):
     :meth:`Index.fill_cache_`, and are maintaned and adjusted over its
     lifespan.
 
-    This representation ensures
+    This representation ensures optimal computation in GNN message passing
     schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
     workflows.
 
@@ -120,7 +120,7 @@ class Index(Tensor):
         assert index.is_sorted
 
         # Flipping order:
-
+        index.flip(0)
         >>> Index([2, 1, 1, 0], dim_size=3)
         assert not index.is_sorted
 
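The repaired docstring example, restated as a runnable snippet (a sketch assuming the `Index` API behaves as documented here, with `flip` dropping the sortedness annotation):

from torch_geometric import Index

index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
assert index.is_sorted

flipped = index.flip(0)   # Index([2, 1, 1, 0], dim_size=3)
assert not flipped.is_sorted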
torch_geometric/nn/models/g_retriever.py
CHANGED
@@ -3,7 +3,6 @@ from typing import List, Optional
 import torch
 from torch import Tensor
 
-from torch_geometric.nn.models import GAT
 from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS
 from torch_geometric.utils import scatter
 
@@ -43,7 +42,6 @@ class GRetriever(torch.nn.Module):
         llm: LLM,
         gnn: torch.nn.Module,
         use_lora: bool = False,
-        gnn_to_use=GAT,
         mlp_out_channels: int = 4096,
     ) -> None:
         super().__init__()
@@ -126,7 +124,15 @@ class GRetriever(torch.nn.Module):
         """
         x = self.encode(x, edge_index, batch, edge_attr)
         x = self.projector(x)
-        xs = x.split(
+        xs = x.split(1, dim=0)
+
+        # Handle questions without node features:
+        batch_unique = batch.unique()
+        batch_size = len(question)
+        if len(batch_unique) < batch_size:
+            xs = [
+                xs[i] if i in batch_unique else None for i in range(batch_size)
+            ]
 
         (
             inputs_embeds,
@@ -174,7 +180,15 @@ class GRetriever(torch.nn.Module):
         """
         x = self.encode(x, edge_index, batch, edge_attr)
         x = self.projector(x)
-        xs = x.split(
+        xs = x.split(1, dim=0)
+
+        # Handle questions without node features:
+        batch_unique = batch.unique()
+        batch_size = len(question)
+        if len(batch_unique) < batch_size:
+            xs = [
+                xs[i] if i in batch_unique else None for i in range(batch_size)
+            ]
 
         inputs_embeds, attention_mask, _ = self.llm._get_embeds(
             question, additional_text_context, xs)
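The block added to both `forward` and `inference` pads the per-graph embedding list with `None` for batch indices that contributed no nodes, keeping exactly one entry per question. A distilled sketch of that guard (tensor shapes are illustrative):

import torch

batch = torch.tensor([0, 0, 2, 2, 2])  # graph 1 contributed no nodes
x = torch.randn(3, 8)                  # one projected embedding per graph slot
xs = list(x.split(1, dim=0))

batch_unique = batch.unique()
batch_size = 3  # == len(question)
if len(batch_unique) < batch_size:
    # Questions whose graph is empty get None instead of a graph token:
    xs = [xs[i] if i in batch_unique else None for i in range(batch_size)]

assert xs[1] is None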
torch_geometric/nn/nlp/llm.py
CHANGED
@@ -1,3 +1,4 @@
+import warnings
 from contextlib import nullcontext
 from typing import Any, Dict, List, Optional
 
@@ -85,6 +86,7 @@ class LLM(torch.nn.Module):
         self.word_embedding = self.llm.model.get_input_embeddings()
 
         if 'max_memory' not in kwargs:  # Pure CPU:
+            warnings.warn("LLM is being used on CPU, which may be slow")
             self.device = torch.device('cpu')
             self.autocast_context = nullcontext()
         else:
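The new warning fires whenever `LLM` is constructed without a `max_memory` hint and therefore falls back to pure CPU execution. If that fallback is intentional, it can be silenced with the standard library filter (a sketch; the message matches the warning added above, and the constructor call is hypothetical):

import warnings

with warnings.catch_warnings():
    warnings.filterwarnings(
        'ignore', message='LLM is being used on CPU, which may be slow')
    # llm = LLM(model_name='...')  # CPU-only instantiation, now silent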
torch_geometric/nn/nlp/sentence_transformer.py
CHANGED
@@ -54,8 +54,11 @@ class SentenceTransformer(torch.nn.Module):
         self,
         text: List[str],
         batch_size: Optional[int] = None,
-        output_device: Optional[torch.device] = None,
+        output_device: Optional[Union[torch.device, str]] = None,
     ) -> Tensor:
+        is_empty = len(text) == 0
+        text = ['dummy'] if is_empty else text
+
         batch_size = len(text) if batch_size is None else batch_size
 
         embs: List[Tensor] = []
@@ -70,11 +73,13 @@ class SentenceTransformer(torch.nn.Module):
             emb = self(
                 input_ids=token.input_ids.to(self.device),
                 attention_mask=token.attention_mask.to(self.device),
-            ).to(output_device
+            ).to(output_device)
 
             embs.append(emb)
 
-
+        out = torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]
+        out = out[:0] if is_empty else out
+        return out
 
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}(model_name={self.model_name})'
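The `is_empty` round-trip above lets `encode` accept an empty input list: a single `'dummy'` string keeps the tokenizer and batching loop well-defined, and slicing with `out[:0]` then returns a correctly shaped `(0, dim)` tensor. The same pattern in isolation (the embedding callable is a stand-in):

import torch

def encode_with_empty_guard(texts, embed):
    # Return a (len(texts), dim) tensor even when `texts` is empty:
    is_empty = len(texts) == 0
    texts = ['dummy'] if is_empty else texts
    out = embed(texts)  # the embedding function always sees >= 1 string
    return out[:0] if is_empty else out

emb = encode_with_empty_guard([], lambda ts: torch.randn(len(ts), 768))
print(emb.shape)  # torch.Size([0, 768])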
{pyg_nightly-2.6.0.dev20240912.dist-info → pyg_nightly-2.7.0.dev20240914.dist-info}/WHEEL
RENAMED
File without changes