pyg-nightly 2.6.0.dev20240912__py3-none-any.whl → 2.7.0.dev20240914__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
{pyg_nightly-2.6.0.dev20240912.dist-info → pyg_nightly-2.7.0.dev20240914.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: pyg-nightly
-Version: 2.6.0.dev20240912
+Version: 2.7.0.dev20240914
 Summary: Graph Neural Network Library for PyTorch
 Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
 Author-email: Matthias Fey <matthias@pyg.org>
{pyg_nightly-2.6.0.dev20240912.dist-info → pyg_nightly-2.7.0.dev20240914.dist-info}/RECORD RENAMED
@@ -1,4 +1,4 @@
-torch_geometric/__init__.py,sha256=zJh2N-U_IS2TvPl3dt8Cas7iUpxUQ0vxgqsO85xR8cA,1904
+torch_geometric/__init__.py,sha256=n23o0W0D6f8VcOq-d2ljeW7QxvqdEiX9hzH_YvgmLK0,1904
 torch_geometric/_compile.py,sha256=0HAdz6MGmyrgi4g6P-PorTg8dPIKx3Jo4zVJavrlfX0,1139
 torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
 torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -7,10 +7,10 @@ torch_geometric/config_store.py,sha256=zdMzlgBpUmBkPovpYQh5fMNwTZLDq2OneqX47QEx7
 torch_geometric/debug.py,sha256=cLyH9OaL2v7POyW-80b19w-ctA7a_5EZsS4aUF1wc2U,1295
 torch_geometric/deprecation.py,sha256=dWRymDIUkUVI2MeEmBG5WF4R6jObZeseSBV9G6FNfjc,858
 torch_geometric/device.py,sha256=tU5-_lBNVbVHl_kUmWPwiG5mQ1pyapwMF4JkmtNN3MM,1224
-torch_geometric/edge_index.py,sha256=Kjkk1kW9C0D2xZPdvTKXQUA4WnqqkHWlUJXpwPG-JWc,70052
+torch_geometric/edge_index.py,sha256=r4_24Rhm3YCK0BF-kzLvL7PlY_1tWcXrBDIr7JPDjkw,70048
 torch_geometric/experimental.py,sha256=JbtNNEXjFGI8hZ9raM6-qrZURP6Z5nlDK8QicZUIbz0,4756
 torch_geometric/home.py,sha256=EV54B4Dmiv61GDbkCwtCfWGWJ4eFGwZ8s3KOgGjwYgY,790
-torch_geometric/index.py,sha256=ZDUq2LTumN1UyYkNF3tYglW8DZ1G-s2ejAIKaNcvfgI,24054
+torch_geometric/index.py,sha256=9ChzWFCwj2slNcVBOgfV-wQn-KscJe_y7502w-Vf76w,24045
 torch_geometric/inspector.py,sha256=9M61T9ruSid5-r2aelRAeX9g_7AZ1VMnYAB2KozM71E,19267
 torch_geometric/isinstance.py,sha256=truZjdU9PxSvjJ6k0d_CLJ2iOpen2o8U-54pbUbNRyE,935
 torch_geometric/lazy_loader.py,sha256=SM0UcXtIdiFge75MKBAWXedoiSOdFDOV0rm1PfoF9cE,908
@@ -53,7 +53,7 @@ torch_geometric/data/temporal.py,sha256=WOJ6gFrTLikaLhUvotyUF5ql14FkE5Ox3hNkdSp6
 torch_geometric/data/view.py,sha256=XjkVSc-UWZFCT4DlXLShZtO8duhFQkS9gq88zZXANsk,1089
 torch_geometric/data/lightning/__init__.py,sha256=w3En1tJfy3kSqe1MycpOyZpHFO3fxBCgNCUOznPA3YU,178
 torch_geometric/data/lightning/datamodule.py,sha256=Bn9iaIfE4NWDDWWMqCvBeZ4bIW1Silx_Ol5CPJCliaQ,29242
-torch_geometric/datasets/__init__.py,sha256=U8ieW-6Xb4Ha1YwjoMqsEEOYziLAweJk5vxx9TPgXqs,5816
+torch_geometric/datasets/__init__.py,sha256=fey-955PyCQXGBeUTNPWwU5uK3PJOEvaY1_fDt1SxXc,5880
 torch_geometric/datasets/actor.py,sha256=H8srMdo5qo8eg4LDxEdYcxZi49I_HVDcr8R_pb2W99Q,4461
 torch_geometric/datasets/airfrans.py,sha256=7Yt0Xs7jx2NotPT4_9GbpLRWRXYSS5g_4zSENoB_9hs,5684
 torch_geometric/datasets/airports.py,sha256=HSZdi6KM_yavppaUl0uWyZ93BEsrtDf9InjPPu9zaUE,3903
@@ -149,6 +149,7 @@ torch_geometric/datasets/tosca.py,sha256=nUSF8NQT1GlkwWQLshjWmr8xORsvRHzzIqhUyDC
 torch_geometric/datasets/tu_dataset.py,sha256=14OSaXBgVwT1dX2h1wZ3xVIwoo0GQBEfR3yWh6Q0VF0,7847
 torch_geometric/datasets/twitch.py,sha256=qfEerf-Uaojx2ZvegENowdG4E7RoUT_HUO9xtULadvo,3658
 torch_geometric/datasets/upfd.py,sha256=crqO8uQNz1wC1JOn4prSs8iOGv9LuLK3dZf_KUV9tUE,7010
+torch_geometric/datasets/web_qsp_dataset.py,sha256=OusHv0DcvDgCjUbBtkhPzwm2pdPlyG98BSzaQPv_GP8,8451
 torch_geometric/datasets/webkb.py,sha256=beC1kWeW7cIjYwWyaINQSk-3lmVR85Lus7cKZniHp8Y,4879
 torch_geometric/datasets/wikics.py,sha256=iTzYif1WvbMXnMdhPMfvrkVaAbnM009WiB_f_JWZqhU,3879
 torch_geometric/datasets/wikidata.py,sha256=9mYShF_HlpTmcdLpiaP_tYJ9eQtUOu5vRPvohN6RXqI,4979
@@ -426,7 +427,7 @@ torch_geometric/nn/models/deep_graph_infomax.py,sha256=u6j-5-iHBASDCZ776dyfCI1N8
 torch_geometric/nn/models/deepgcn.py,sha256=tIgT03cj8MghYlxEozpoGvGG_CwpJrGDxv1Z0CVIUts,4339
 torch_geometric/nn/models/dimenet.py,sha256=Kc5p-rB5q-0e8lY22l-OdQTscTxJh2lTEpeRFMdL4RY,36186
 torch_geometric/nn/models/dimenet_utils.py,sha256=xP_nbzkSSL25GC3rrZ9KP8x9QZ59S-CZuHzCmQ-K0fI,5062
-torch_geometric/nn/models/g_retriever.py,sha256=uH_aYrFbFNHaAeKQn_LtUgP5ajutLYYD8N9UvSKcpfk,7271
+torch_geometric/nn/models/g_retriever.py,sha256=VueRImNJlh1WvRWcsSXliSw8RlxlzWlu2WSFs_VQaJc,7749
 torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
 torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
 torch_geometric/nn/models/graph_unet.py,sha256=WFb7d_DBByMGyXh3AdK2CKNmvMmSKsSUt8l8UnSOovs,5395
@@ -449,8 +450,8 @@ torch_geometric/nn/models/signed_gcn.py,sha256=J40CnedFIqtKI1LhW1ITSEFRbA_XiJZL6
 torch_geometric/nn/models/tgn.py,sha256=kEGdfLJybkbMT4UMoAh2nCzfX3_nDjfm1cicuPHEwAM,11878
 torch_geometric/nn/models/visnet.py,sha256=97OFMCsPDEI5BCSi7RhoRcU2CNRp7zck2tEzrltFZj4,43192
 torch_geometric/nn/nlp/__init__.py,sha256=JJESTA7w_K8v60XbCd25IqmrKKHLz5OiNexMHYGV2mE,138
-torch_geometric/nn/nlp/llm.py,sha256=KwSXgI55FuHLR_9vhgekDXMaRUodPQceHPD7OCp2KN4,11639
-torch_geometric/nn/nlp/sentence_transformer.py,sha256=DzbQO8wgR34BkKpXfMqQu61hMrK94W2MBa3bZ4fDmVs,3114
+torch_geometric/nn/nlp/llm.py,sha256=a5YkJA32Ok2PmWFEJ0VJD0HfsauDpxosIwlij6wqwJo,11728
+torch_geometric/nn/nlp/sentence_transformer.py,sha256=JrTN3W1srdkNX7qYDGB08mY5615i5nfEJSTHAdd5EuA,3260
 torch_geometric/nn/norm/__init__.py,sha256=u2qIDrkbeuObGVXSAIftAlvSd6ouGTtxznCfD-59UiA,669
 torch_geometric/nn/norm/batch_norm.py,sha256=sJKrinHGwA-noIgteg1RD2W06rd0zskD-rXuY-36glY,8283
 torch_geometric/nn/norm/diff_group_norm.py,sha256=b57XvNekrUYGDjNJlGeqvaMGNJmHwopSF0_yyBWlLuA,4722
@@ -617,6 +618,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
 torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
 torch_geometric/visualization/graph.py,sha256=SvbdVx5Zmuy_WSSA4-WWCkqAcCSHVe84mjMfsEWbZCs,4813
 torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
-pyg_nightly-2.6.0.dev20240912.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
-pyg_nightly-2.6.0.dev20240912.dist-info/METADATA,sha256=D-HQbnicYK-Tr_IxBWo2XuDsqcPG2_K5zrYJXg91xnQ,63068
-pyg_nightly-2.6.0.dev20240912.dist-info/RECORD,,
+pyg_nightly-2.7.0.dev20240914.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
+pyg_nightly-2.7.0.dev20240914.dist-info/METADATA,sha256=Pz5f49zwvSFV-FkxufA5cJDayObU27GV4gtdnv1KN0g,63068
+pyg_nightly-2.7.0.dev20240914.dist-info/RECORD,,
torch_geometric/__init__.py CHANGED
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
 
-__version__ = '2.6.0.dev20240912'
+__version__ = '2.7.0.dev20240914'
 
 __all__ = [
     'Index',
torch_geometric/datasets/__init__.py CHANGED
@@ -61,7 +61,6 @@ from .gemsec import GemsecDeezer
 from .twitch import Twitch
 from .airports import Airports
 from .lrgb import LRGBDataset
-from .neurograph import NeuroGraphDataset
 from .malnet_tiny import MalNetTiny
 from .omdb import OMDB
 from .polblogs import PolBlogs
@@ -76,6 +75,8 @@ from .jodie import JODIEDataset
 from .wikidata import Wikidata5M
 from .myket import MyketDataset
 from .brca_tgca import BrcaTcga
+from .neurograph import NeuroGraphDataset
+from .web_qsp_dataset import WebQSPDataset
 
 from .dbp15k import DBP15K
 from .aminer import AMiner
@@ -188,6 +189,7 @@ homo_datasets = [
     'MyketDataset',
     'BrcaTcga',
     'NeuroGraphDataset',
+    'WebQSPDataset',
 ]
 
 hetero_datasets = [
torch_geometric/datasets/web_qsp_dataset.py ADDED
@@ -0,0 +1,239 @@
+# Code adapted from the G-Retriever paper: https://arxiv.org/abs/2402.07630
+from typing import Any, Dict, List, Tuple, no_type_check
+
+import numpy as np
+import torch
+from torch import Tensor
+from tqdm import tqdm
+
+from torch_geometric.data import Data, InMemoryDataset
+from torch_geometric.nn.nlp import SentenceTransformer
+
+
+@no_type_check
+def retrieval_via_pcst(
+    data: Data,
+    q_emb: Tensor,
+    textual_nodes: Any,
+    textual_edges: Any,
+    topk: int = 3,
+    topk_e: int = 3,
+    cost_e: float = 0.5,
+) -> Tuple[Data, str]:
+    c = 0.01
+    if len(textual_nodes) == 0 or len(textual_edges) == 0:
+        desc = textual_nodes.to_csv(index=False) + "\n" + textual_edges.to_csv(
+            index=False,
+            columns=["src", "edge_attr", "dst"],
+        )
+        return data, desc
+
+    from pcst_fast import pcst_fast
+
+    root = -1
+    num_clusters = 1
+    pruning = 'gw'
+    verbosity_level = 0
+    if topk > 0:
+        n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.x)
+        topk = min(topk, data.num_nodes)
+        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
+
+        n_prizes = torch.zeros_like(n_prizes)
+        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
+    else:
+        n_prizes = torch.zeros(data.num_nodes)
+
+    if topk_e > 0:
+        e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.edge_attr)
+        topk_e = min(topk_e, e_prizes.unique().size(0))
+
+        topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e, largest=True)
+        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
+        last_topk_e_value = topk_e
+        for k in range(topk_e):
+            indices = e_prizes == topk_e_values[k]
+            value = min((topk_e - k) / sum(indices), last_topk_e_value - c)
+            e_prizes[indices] = value
+            last_topk_e_value = value * (1 - c)
+        # reduce the cost of the edges such that at least one edge is selected
+        cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
+    else:
+        e_prizes = torch.zeros(data.num_edges)
+
+    costs = []
+    edges = []
+    virtual_n_prizes = []
+    virtual_edges = []
+    virtual_costs = []
+    mapping_n = {}
+    mapping_e = {}
+    for i, (src, dst) in enumerate(data.edge_index.t().numpy()):
+        prize_e = e_prizes[i]
+        if prize_e <= cost_e:
+            mapping_e[len(edges)] = i
+            edges.append((src, dst))
+            costs.append(cost_e - prize_e)
+        else:
+            virtual_node_id = data.num_nodes + len(virtual_n_prizes)
+            mapping_n[virtual_node_id] = i
+            virtual_edges.append((src, virtual_node_id))
+            virtual_edges.append((virtual_node_id, dst))
+            virtual_costs.append(0)
+            virtual_costs.append(0)
+            virtual_n_prizes.append(prize_e - cost_e)
+
+    prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])
+    num_edges = len(edges)
+    if len(virtual_costs) > 0:
+        costs = np.array(costs + virtual_costs)
+        edges = np.array(edges + virtual_edges)
+
+    vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters,
+                                pruning, verbosity_level)
+
+    selected_nodes = vertices[vertices < data.num_nodes]
+    selected_edges = [mapping_e[e] for e in edges if e < num_edges]
+    virtual_vertices = vertices[vertices >= data.num_nodes]
+    if len(virtual_vertices) > 0:
+        virtual_vertices = vertices[vertices >= data.num_nodes]
+        virtual_edges = [mapping_n[i] for i in virtual_vertices]
+        selected_edges = np.array(selected_edges + virtual_edges)
+
+    edge_index = data.edge_index[:, selected_edges]
+    selected_nodes = np.unique(
+        np.concatenate(
+            [selected_nodes, edge_index[0].numpy(), edge_index[1].numpy()]))
+
+    n = textual_nodes.iloc[selected_nodes]
+    e = textual_edges.iloc[selected_edges]
+    desc = n.to_csv(index=False) + '\n' + e.to_csv(
+        index=False, columns=['src', 'edge_attr', 'dst'])
+
+    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
+    src = [mapping[i] for i in edge_index[0].tolist()]
+    dst = [mapping[i] for i in edge_index[1].tolist()]
+
+    data = Data(
+        x=data.x[selected_nodes],
+        edge_index=torch.tensor([src, dst]),
+        edge_attr=data.edge_attr[selected_edges],
+    )
+
+    return data, desc
+
+
+class WebQSPDataset(InMemoryDataset):
+    r"""The WebQuestionsSP dataset of the `"The Value of Semantic Parse
+    Labeling for Knowledge Base Question Answering"
+    <https://aclanthology.org/P16-2033/>`_ paper.
+
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        split (str, optional): If :obj:`"train"`, loads the training dataset.
+            If :obj:`"val"`, loads the validation dataset.
+            If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
+        force_reload (bool, optional): Whether to re-process the dataset.
+            (default: :obj:`False`)
+    """
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        force_reload: bool = False,
+    ) -> None:
+        super().__init__(root, force_reload=force_reload)
+
+        if split not in {'train', 'val', 'test'}:
+            raise ValueError(f"Invalid 'split' argument (got {split})")
+
+        path = self.processed_paths[['train', 'val', 'test'].index(split)]
+        self.load(path)
+
+    @property
+    def processed_file_names(self) -> List[str]:
+        return ['train_data.pt', 'val_data.pt', 'test_data.pt']
+
+    def process(self) -> None:
+        import datasets
+        import pandas as pd
+
+        datasets = datasets.load_dataset('rmanluo/RoG-webqsp')
+
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        model_name = 'sentence-transformers/all-roberta-large-v1'
+        model = SentenceTransformer(model_name).to(device)
+        model.eval()
+
+        for dataset, path in zip(
+                [datasets['train'], datasets['validation'], datasets['test']],
+                self.processed_paths,
+        ):
+            questions = [example["question"] for example in dataset]
+            question_embs = model.encode(
+                questions,
+                batch_size=256,
+                output_device='cpu',
+            )
+
+            data_list = []
+            for i, example in enumerate(tqdm(dataset)):
+                raw_nodes: Dict[str, int] = {}
+                raw_edges = []
+                for tri in example["graph"]:
+                    h, r, t = tri
+                    h = h.lower()
+                    t = t.lower()
+                    if h not in raw_nodes:
+                        raw_nodes[h] = len(raw_nodes)
+                    if t not in raw_nodes:
+                        raw_nodes[t] = len(raw_nodes)
+                    raw_edges.append({
+                        "src": raw_nodes[h],
+                        "edge_attr": r,
+                        "dst": raw_nodes[t]
+                    })
+                nodes = pd.DataFrame([{
+                    "node_id": v,
+                    "node_attr": k,
+                } for k, v in raw_nodes.items()])
+                edges = pd.DataFrame(raw_edges)
+
+                nodes.node_attr = nodes.node_attr.fillna("")
+                x = model.encode(
+                    nodes.node_attr.tolist(),
+                    batch_size=256,
+                    output_device='cpu',
+                )
+                edge_attr = model.encode(
+                    edges.edge_attr.tolist(),
+                    batch_size=256,
+                    output_device='cpu',
+                )
+                edge_index = torch.tensor([
+                    edges.src.tolist(),
+                    edges.dst.tolist(),
+                ])
+
+                question = f"Question: {example['question']}\nAnswer: "
+                label = ('|').join(example['answer']).lower()
+                data = Data(
+                    x=x,
+                    edge_index=edge_index,
+                    edge_attr=edge_attr,
+                )
+                data, desc = retrieval_via_pcst(
+                    data,
+                    question_embs[i],
+                    nodes,
+                    edges,
+                    topk=3,
+                    topk_e=5,
+                    cost_e=0.5,
+                )
+                data.question = question
+                data.label = label
+                data.desc = desc
+                data_list.append(data)
+
+            self.save(data_list, path)
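
For orientation, here is a minimal usage sketch of the new dataset (not part of the diff; the root path is arbitrary). It assumes the optional dependencies datasets, pandas, and pcst_fast are installed; the first instantiation downloads 'rmanluo/RoG-webqsp' from the Hugging Face Hub and encodes every graph with all-roberta-large-v1, which is slow without a GPU:

from torch_geometric.datasets import WebQSPDataset

# First run triggers download + processing; later runs load the cached
# 'train_data.pt' from '<root>/processed/':
train_dataset = WebQSPDataset(root='data/WebQSP', split='train')

data = train_dataset[0]      # PCST-retrieved subgraph for one question
print(data.question)         # "Question: ...\nAnswer: "
print(data.label)            # '|'-joined, lowercased answer entities
print(data.desc)             # CSV description of the selected subgraph
print(data.x.shape)          # node embeddings (1024-dim for roberta-large)
print(data.edge_attr.shape)  # relation embeddings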
torch_geometric/edge_index.py CHANGED
@@ -173,7 +173,7 @@ class EdgeIndex(Tensor):
     :meth:`EdgeIndex.fill_cache_`, and are maintained and adjusted over its
     lifespan (*e.g.*, when calling :meth:`EdgeIndex.flip`).
 
-    This representation ensures for optimal computation in GNN message passing
+    This representation ensures optimal computation in GNN message passing
     schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
     workflows.
 
torch_geometric/index.py CHANGED
@@ -106,7 +106,7 @@ class Index(Tensor):
     :meth:`Index.fill_cache_`, and are maintained and adjusted over its
     lifespan.
 
-    This representation ensures for optimal computation in GNN message passing
+    This representation ensures optimal computation in GNN message passing
     schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
     workflows.
 
@@ -120,7 +120,7 @@ class Index(Tensor):
         assert index.is_sorted
 
         # Flipping order:
-        edge_index.flip(0)
+        index.flip(0)
         >>> Index([2, 1, 1, 0], dim_size=3)
         assert not index.is_sorted
 
torch_geometric/nn/models/g_retriever.py CHANGED
@@ -3,7 +3,6 @@ from typing import List, Optional
 import torch
 from torch import Tensor
 
-from torch_geometric.nn.models import GAT
 from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS
 from torch_geometric.utils import scatter
 
@@ -43,7 +42,6 @@ class GRetriever(torch.nn.Module):
         llm: LLM,
         gnn: torch.nn.Module,
         use_lora: bool = False,
-        gnn_to_use=GAT,
         mlp_out_channels: int = 4096,
     ) -> None:
         super().__init__()
@@ -126,7 +124,15 @@ class GRetriever(torch.nn.Module):
         """
         x = self.encode(x, edge_index, batch, edge_attr)
         x = self.projector(x)
-        xs = x.split(x.size(0), dim=0)
+        xs = x.split(1, dim=0)
+
+        # Handle questions without node features:
+        batch_unique = batch.unique()
+        batch_size = len(question)
+        if len(batch_unique) < batch_size:
+            xs = [
+                xs[i] if i in batch_unique else None for i in range(batch_size)
+            ]
 
         (
             inputs_embeds,
@@ -174,7 +180,15 @@ class GRetriever(torch.nn.Module):
         """
         x = self.encode(x, edge_index, batch, edge_attr)
         x = self.projector(x)
-        xs = x.split(x.size(0), dim=0)
+        xs = x.split(1, dim=0)
+
+        # Handle questions without node features:
+        batch_unique = batch.unique()
+        batch_size = len(question)
+        if len(batch_unique) < batch_size:
+            xs = [
+                xs[i] if i in batch_unique else None for i in range(batch_size)
+            ]
 
         inputs_embeds, attention_mask, _ = self.llm._get_embeds(
             question, additional_text_context, xs)
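
The two identical blocks added above handle batches in which some questions retrieved an empty subgraph: an empty graph contributes no rows to batch, so its index is missing from batch.unique(), and its slot in xs is set to None so no graph embedding is injected for that question. A toy illustration of the logic, with hypothetical shapes (not part of the diff):

import torch

batch = torch.tensor([0, 0, 2, 2, 2])  # graph 1 contributed no nodes
question = ['q0', 'q1', 'q2']          # one question per graph

x = torch.randn(3, 4096)               # pooled + projected graph embeddings
xs = x.split(1, dim=0)                 # one (1, 4096) tensor per graph

batch_unique = batch.unique()          # tensor([0, 2])
batch_size = len(question)
if len(batch_unique) < batch_size:
    xs = [xs[i] if i in batch_unique else None for i in range(batch_size)]

assert xs[1] is None                   # question 1 gets no graph embedding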
torch_geometric/nn/nlp/llm.py CHANGED
@@ -1,3 +1,4 @@
+import warnings
 from contextlib import nullcontext
 from typing import Any, Dict, List, Optional
 
@@ -85,6 +86,7 @@ class LLM(torch.nn.Module):
         self.word_embedding = self.llm.model.get_input_embeddings()
 
         if 'max_memory' not in kwargs:  # Pure CPU:
+            warnings.warn("LLM is being used on CPU, which may be slow")
             self.device = torch.device('cpu')
             self.autocast_context = nullcontext()
         else:
torch_geometric/nn/nlp/sentence_transformer.py CHANGED
@@ -54,8 +54,11 @@ class SentenceTransformer(torch.nn.Module):
         self,
         text: List[str],
         batch_size: Optional[int] = None,
-        output_device: Optional[torch.device] = None,
+        output_device: Optional[Union[torch.device, str]] = None,
     ) -> Tensor:
+        is_empty = len(text) == 0
+        text = ['dummy'] if is_empty else text
+
         batch_size = len(text) if batch_size is None else batch_size
 
         embs: List[Tensor] = []
@@ -70,11 +73,13 @@ class SentenceTransformer(torch.nn.Module):
             emb = self(
                 input_ids=token.input_ids.to(self.device),
                 attention_mask=token.attention_mask.to(self.device),
-            ).to(output_device or 'cpu')
+            ).to(output_device)
 
             embs.append(emb)
 
-        return torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]
+        out = torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]
+        out = out[:0] if is_empty else out
+        return out
 
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}(model_name={self.model_name})'
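
The encode() changes above make empty input safe: a placeholder string is encoded and then sliced away again via out[:0], so a caller passing an empty list (e.g., a graph with no edges during WebQSP processing) gets a zero-row tensor of the right width instead of an error. A hedged sketch of the new behavior (model weights are downloaded on first use):

import torch
from torch_geometric.nn.nlp import SentenceTransformer

model = SentenceTransformer('sentence-transformers/all-roberta-large-v1')

emb = model.encode([])                       # previously an error path
assert emb.dim() == 2 and emb.size(0) == 0   # zero rows, correct width

# 'output_device' now also accepts a plain string:
emb = model.encode(['who directed pulp fiction?'], output_device='cpu')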