pyg-nightly 2.7.0.dev20250904__py3-none-any.whl → 2.7.0.dev20250906__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250904.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/METADATA +2 -1
- {pyg_nightly-2.7.0.dev20250904.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/RECORD +34 -27
- torch_geometric/__init__.py +1 -1
- torch_geometric/data/__init__.py +0 -5
- torch_geometric/data/lightning/datamodule.py +2 -2
- torch_geometric/datasets/molecule_gpt_dataset.py +1 -1
- torch_geometric/datasets/web_qsp_dataset.py +262 -210
- torch_geometric/graphgym/imports.py +2 -2
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/{data → llm}/large_graph_indexer.py +124 -61
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +68 -49
- torch_geometric/{nn → llm}/models/git_mol.py +1 -1
- torch_geometric/{nn/nlp → llm/models}/llm.py +167 -33
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/{nn → llm}/models/molecule_gpt.py +1 -1
- torch_geometric/{nn/nlp → llm/models}/sentence_transformer.py +42 -8
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/backend_utils.py +442 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +124 -0
- torch_geometric/loader/__init__.py +0 -4
- torch_geometric/metrics/link_pred.py +13 -2
- torch_geometric/nn/__init__.py +0 -1
- torch_geometric/nn/models/__init__.py +0 -10
- torch_geometric/nn/models/sgformer.py +2 -0
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/loader/rag_loader.py +0 -107
- torch_geometric/nn/nlp/__init__.py +0 -9
- {pyg_nightly-2.7.0.dev20250904.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250904.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/licenses/LICENSE +0 -0
- /torch_geometric/{nn → llm}/models/glem.py +0 -0
- /torch_geometric/{nn → llm}/models/protein_mpnn.py +0 -0
- /torch_geometric/{nn/nlp → llm/models}/vision_transformer.py +0 -0
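The renames above consolidate PyG's LLM-related code (G-Retriever, GIT-Mol, MoleculeGPT, the NLP wrappers, and the RAG utilities) under a new `torch_geometric.llm` package. A minimal sketch of the import layout implied by the rename mappings; the exported class names are assumptions based on the file names and are not confirmed by this diff:

```python
# Import layout implied by the renames above (a sketch, not an exhaustive list).
from torch_geometric.llm.large_graph_indexer import LargeGraphIndexer  # was torch_geometric.data
from torch_geometric.llm.models import SentenceTransformer  # was torch_geometric.nn.nlp
from torch_geometric.llm.models import GRetriever  # was torch_geometric.nn.models (g_retriever.py)
from torch_geometric.llm.utils.backend_utils import retrieval_via_pcst  # new helper module
```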
```diff
--- a/torch_geometric/datasets/web_qsp_dataset.py
+++ b/torch_geometric/datasets/web_qsp_dataset.py
@@ -1,120 +1,26 @@
 # Code adapted from the G-Retriever paper: https://arxiv.org/abs/2402.07630
-
+import gc
+import os
+from itertools import chain
+from typing import Any, Dict, Iterator, List, Optional
 
-import numpy as np
 import torch
-from torch import Tensor
 from tqdm import tqdm
 
-from torch_geometric.data import
-from torch_geometric.
-
-
-
-
-
-
-
-
-
-
-)
-    c = 0.01
-
-    from pcst_fast import pcst_fast
-
-    root = -1
-    num_clusters = 1
-    pruning = 'gw'
-    verbosity_level = 0
-    if topk > 0:
-        n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.x)
-        topk = min(topk, data.num_nodes)
-        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
-
-        n_prizes = torch.zeros_like(n_prizes)
-        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
-    else:
-        n_prizes = torch.zeros(data.num_nodes)
-
-    if topk_e > 0:
-        e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.edge_attr)
-        topk_e = min(topk_e, e_prizes.unique().size(0))
-
-        topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e, largest=True)
-        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
-        last_topk_e_value = topk_e
-        for k in range(topk_e):
-            indices = e_prizes == topk_e_values[k]
-            value = min((topk_e - k) / sum(indices), last_topk_e_value - c)
-            e_prizes[indices] = value
-            last_topk_e_value = value * (1 - c)
-        # reduce the cost of the edges such that at least one edge is selected
-        cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
-    else:
-        e_prizes = torch.zeros(data.num_edges)
-
-    costs = []
-    edges = []
-    virtual_n_prizes = []
-    virtual_edges = []
-    virtual_costs = []
-    mapping_n = {}
-    mapping_e = {}
-    for i, (src, dst) in enumerate(data.edge_index.t().numpy()):
-        prize_e = e_prizes[i]
-        if prize_e <= cost_e:
-            mapping_e[len(edges)] = i
-            edges.append((src, dst))
-            costs.append(cost_e - prize_e)
-        else:
-            virtual_node_id = data.num_nodes + len(virtual_n_prizes)
-            mapping_n[virtual_node_id] = i
-            virtual_edges.append((src, virtual_node_id))
-            virtual_edges.append((virtual_node_id, dst))
-            virtual_costs.append(0)
-            virtual_costs.append(0)
-            virtual_n_prizes.append(prize_e - cost_e)
-
-    prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])
-    num_edges = len(edges)
-    if len(virtual_costs) > 0:
-        costs = np.array(costs + virtual_costs)
-        edges = np.array(edges + virtual_edges)
-
-    vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters,
-                                pruning, verbosity_level)
-
-    selected_nodes = vertices[vertices < data.num_nodes]
-    selected_edges = [mapping_e[e] for e in edges if e < num_edges]
-    virtual_vertices = vertices[vertices >= data.num_nodes]
-    if len(virtual_vertices) > 0:
-        virtual_vertices = vertices[vertices >= data.num_nodes]
-        virtual_edges = [mapping_n[i] for i in virtual_vertices]
-        selected_edges = np.array(selected_edges + virtual_edges)
-
-    edge_index = data.edge_index[:, selected_edges]
-    selected_nodes = np.unique(
-        np.concatenate(
-            [selected_nodes, edge_index[0].numpy(), edge_index[1].numpy()]))
-
-    n = textual_nodes.iloc[selected_nodes]
-    e = textual_edges.iloc[selected_edges]
-    desc = n.to_csv(index=False) + '\n' + e.to_csv(
-        index=False, columns=['src', 'edge_attr', 'dst'])
-
-    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
-    src = [mapping[i] for i in edge_index[0].tolist()]
-    dst = [mapping[i] for i in edge_index[1].tolist()]
-
-    data = Data(
-        x=data.x[selected_nodes],
-        edge_index=torch.tensor([src, dst]),
-        edge_attr=data.edge_attr[selected_edges],
-    )
-
-    return data, desc
+from torch_geometric.data import InMemoryDataset
+from torch_geometric.llm.large_graph_indexer import (
+    EDGE_RELATION,
+    LargeGraphIndexer,
+    TripletLike,
+    get_features_for_triplets_groups,
+)
+from torch_geometric.llm.models import SentenceTransformer
+from torch_geometric.llm.utils.backend_utils import retrieval_via_pcst
+
+
+def preprocess_triplet(triplet: TripletLike) -> TripletLike:
+    h, r, t = triplet
+    return str(h).lower(), str(r).lower(), str(t).lower()
 
 
 class KGQABaseDataset(InMemoryDataset):
```
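The hunk above removes the local `retrieval_via_pcst` implementation from `web_qsp_dataset.py`; the dataset now imports it from `torch_geometric.llm.utils.backend_utils`. For reference, a self-contained sketch of the node-prize step visible in the removed body, where the top-k nodes by cosine similarity to the question embedding receive descending integer prizes; tensor shapes and values are illustrative assumptions:

```python
import torch

q_emb = torch.randn(1, 1024)  # question embedding
x = torch.randn(100, 1024)    # node features of the candidate graph
topk = 3

# Cosine similarity of every node against the question embedding.
n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, x)  # shape: (100,)
topk = min(topk, x.size(0))
_, topk_n_indices = torch.topk(n_prizes, topk, largest=True)

# Top-k nodes get prizes topk, topk-1, ..., 1; all other nodes get 0.
prizes = torch.zeros_like(n_prizes)
prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
```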
```diff
@@ -130,8 +36,16 @@ class KGQABaseDataset(InMemoryDataset):
             If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
         force_reload (bool, optional): Whether to re-process the dataset.
             (default: :obj:`False`)
+        verbose (bool, optional): Whether to print output. Defaults to False.
         use_pcst (bool, optional): Whether to preprocess the dataset's graph
             with PCST or return the full graphs. (default: :obj:`True`)
+        load_dataset_kwargs (dict, optional):
+            Keyword arguments for the `datasets.load_dataset` function.
+            (default: :obj:`{}`)
+        retrieval_kwargs (dict, optional):
+            Keyword arguments for the
+            `get_features_for_triplets_groups` function.
+            (default: :obj:`{}`)
     """
     def __init__(
         self,
```
```diff
@@ -139,115 +53,206 @@ class KGQABaseDataset(InMemoryDataset):
         root: str,
         split: str = "train",
         force_reload: bool = False,
+        verbose: bool = False,
         use_pcst: bool = True,
-
+        load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+        retrieval_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
+        self.split = split
         self.dataset_name = dataset_name
         self.use_pcst = use_pcst
+        self.load_dataset_kwargs = load_dataset_kwargs or {}
+        """
+        NOTE: If running into memory issues,
+        try reducing this batch size for the LargeGraphIndexer
+        used to build our KG.
+        Example: self.retrieval_kwargs = {"batch_size": 64}
+        """
+        self.retrieval_kwargs = retrieval_kwargs or {}
+
+        # Caching custom subsets of the dataset results in unsupported behavior
+        if 'split' in self.load_dataset_kwargs:
+            print("WARNING: Caching custom subsets of the dataset \
+                results in unsupported behavior.\
+                Please specify a separate root directory for each split,\
+                or set force_reload=True on subsequent instantiations\
+                of the dataset.")
+
+        self.required_splits = ['train', 'validation', 'test']
+
+        self.verbose = verbose
+        self.force_reload = force_reload
         super().__init__(root, force_reload=force_reload)
-
-
+        """
+        NOTE: Current behavior is to process the entire dataset,
+        and only return the split specified by the user.
+        """
+        if f'{split}_data.pt' not in set(self.processed_file_names):
             raise ValueError(f"Invalid 'split' argument (got {split})")
+        if split == 'val':
+            split = 'validation'
 
-
-
+        self.load(self.processed_paths[self.required_splits.index(split)])
+
+    @property
+    def raw_file_names(self) -> List[str]:
+        return ["raw.pt"]
 
     @property
     def processed_file_names(self) -> List[str]:
-        return [
+        return ["train_data.pt", "val_data.pt", "test_data.pt"]
 
-    def
+    def download(self) -> None:
         import datasets
-        import pandas as pd
-
-        datasets = datasets.load_dataset(self.dataset_name)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # HF Load Dataset by dataset name if no path is specified
+        self.load_dataset_kwargs['path'] = self.load_dataset_kwargs.get(
+            'path', self.dataset_name)
+        raw_dataset = datasets.load_dataset(**self.load_dataset_kwargs)
+
+        # Assert that the dataset contains the required splits
+        assert all(split in raw_dataset for split in self.required_splits), \
+            f"Dataset '{self.dataset_name}' is missing required splits: \
+            {self.required_splits}"
+
+        raw_dataset.save_to_disk(self.raw_paths[0])
+
+    def _get_trips(self) -> Iterator[TripletLike]:
+        # Iterate over each element's graph in each split of the dataset
+        # Using chain to lazily iterate without storing all trips in memory
+        split_iterators = []
+
+        for split in self.required_splits:
+            # Create an iterator for each element's graph in the current split
+            split_graphs = (element['graph']
+                            for element in self.raw_dataset[split])
+            split_iterators.append(chain.from_iterable(split_graphs))
+
+        # Chain all split iterators together
+        return chain.from_iterable(split_iterators)
+
+    def _build_graph(self) -> None:
+        print("Encoding graph...")
+        trips = self._get_trips()
+        self.indexer: LargeGraphIndexer = LargeGraphIndexer.from_triplets(
+            trips, pre_transform=preprocess_triplet)
+
+        # Nodes:
+        print("\tEncoding nodes...")
+        nodes = self.indexer.get_unique_node_features()
+        x = self.model.encode(nodes, batch_size=256, output_device='cpu')
+        self.indexer.add_node_feature(new_feature_name="x", new_feature_vals=x)
+
+        # Edges:
+        print("\tEncoding edges...")
+        edges = self.indexer.get_unique_edge_features(
+            feature_name=EDGE_RELATION)
+        edge_attr = self.model.encode(edges, batch_size=256,
+                                      output_device='cpu')
+        self.indexer.add_edge_feature(
+            new_feature_name="edge_attr",
+            new_feature_vals=edge_attr,
+            map_from_feature=EDGE_RELATION,
+        )
+
+        print("\tSaving graph...")
+        self.indexer.save(self.indexer_path)
+
+    def _retrieve_subgraphs(self) -> None:
+        raw_splits = [
+            self.raw_dataset[split] for split in self.required_splits
+        ]
+        zipped = zip(
+            self.required_splits,
+            raw_splits,  # noqa
+            self.processed_paths,
+        )
+        for split_name, dataset, path in zipped:
+            print(f"Processing {split_name} split...")
+
+            print("\tEncoding questions...")
+            split_questions = [str(element['question']) for element in dataset]
+            split_q_embs = self.model.encode(split_questions, batch_size=256,
+                                             output_device='cpu')
+
+            print("\tRetrieving subgraphs...")
+            results_graphs = []
+            retrieval_kwargs = {
+                **self.retrieval_kwargs,
+                **{
+                    'pre_transform': preprocess_triplet,
+                    'verbose': self.verbose,
+                }
+            }
+            graph_gen = get_features_for_triplets_groups(
+                self.indexer, (element['graph'] for element in dataset),
+                **retrieval_kwargs)
+
+            for index in tqdm(range(len(dataset)), disable=not self.verbose):
+                data_i = dataset[index]
+                graph = next(graph_gen)
+                textual_nodes = self.textual_nodes.iloc[
+                    graph["node_idx"]].reset_index()
+                textual_edges = self.textual_edges.iloc[
+                    graph["edge_idx"]].reset_index()
+                if self.use_pcst and len(textual_nodes) > 0 and len(
+                        textual_edges) > 0:
+                    subgraph, desc = retrieval_via_pcst(
+                        graph,
+                        split_q_embs[index],
+                        textual_nodes,
+                        textual_edges,
                     )
                 else:
-                    desc =
-                    index=False
-
-
+                    desc = textual_nodes.to_csv(
+                        index=False) + "\n" + textual_edges.to_csv(
+                            index=False,
+                            columns=["src", "edge_attr", "dst"],
+                        )
+                    subgraph = graph
+                question = f"Question: {data_i['question']}\nAnswer: "
+                label = ("|").join(data_i["answer"]).lower()
+
+                subgraph["question"] = question
+                subgraph["label"] = label
+                subgraph["desc"] = desc
+                results_graphs.append(subgraph.to("cpu"))
+            print("\tSaving subgraphs...")
+            self.save(results_graphs, path)
 
-
-
-
-
+    def process(self) -> None:
+        import datasets
+        from pandas import DataFrame
+        self.raw_dataset = datasets.load_from_disk(self.raw_paths[0])
 
-
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        model_name = 'sentence-transformers/all-roberta-large-v1'
+        self.model: SentenceTransformer = SentenceTransformer(model_name).to(
+            device)
+        self.model.eval()
+        self.indexer_path = os.path.join(self.processed_dir,
+                                         "large_graph_indexer")
+        if self.force_reload or not os.path.exists(self.indexer_path):
+            self._build_graph()
+        else:
+            print("Loading graph...")
+            self.indexer = LargeGraphIndexer.from_disk(self.indexer_path)
+        self.textual_nodes = DataFrame.from_dict(
+            {"node_attr": self.indexer.get_node_features()})
+        self.textual_nodes["node_id"] = self.textual_nodes.index
+        self.textual_nodes = self.textual_nodes[["node_id", "node_attr"]]
+        self.textual_edges = DataFrame(self.indexer.get_edge_features(),
+                                       columns=["src", "edge_attr", "dst"])
+        self.textual_edges["src"] = [
+            self.indexer._nodes[h] for h in self.textual_edges["src"]
+        ]
+        self.textual_edges["dst"] = [
+            self.indexer._nodes[h] for h in self.textual_edges["dst"]
+        ]
+        self._retrieve_subgraphs()
+
+        gc.collect()
+        torch.cuda.empty_cache()
 
 
 class WebQSPDataset(KGQABaseDataset):
```
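The rewritten `process()` above builds the knowledge graph once with `LargeGraphIndexer`, encodes node and edge text with a `SentenceTransformer`, and caches the index next to the processed files. A hedged, standalone sketch of that encode-and-index flow; the toy triplets and the save path are assumptions, while the method calls mirror the hunk above:

```python
import torch

from torch_geometric.llm.large_graph_indexer import (
    EDGE_RELATION,
    LargeGraphIndexer,
)
from torch_geometric.llm.models import SentenceTransformer

# Toy knowledge-graph triplets (head, relation, tail).
triplets = [
    ("barack obama", "born in", "honolulu"),
    ("honolulu", "located in", "hawaii"),
]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SentenceTransformer(
    'sentence-transformers/all-roberta-large-v1').to(device)
model.eval()

# Index the triplets, then attach encoded node and edge features.
indexer = LargeGraphIndexer.from_triplets(triplets)

x = model.encode(indexer.get_unique_node_features(), batch_size=256,
                 output_device='cpu')
indexer.add_node_feature(new_feature_name="x", new_feature_vals=x)

edge_attr = model.encode(
    indexer.get_unique_edge_features(feature_name=EDGE_RELATION),
    batch_size=256, output_device='cpu')
indexer.add_edge_feature(new_feature_name="edge_attr",
                         new_feature_vals=edge_attr,
                         map_from_feature=EDGE_RELATION)

# Cache to disk; reload later with LargeGraphIndexer.from_disk(path).
indexer.save("processed/large_graph_indexer")
```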
```diff
@@ -262,13 +267,40 @@ class WebQSPDataset(KGQABaseDataset):
             If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
         force_reload (bool, optional): Whether to re-process the dataset.
             (default: :obj:`False`)
+        verbose (bool, optional): Whether to print output. Defaults to False.
         use_pcst (bool, optional): Whether to preprocess the dataset's graph
             with PCST or return the full graphs. (default: :obj:`True`)
+        load_dataset_kwargs (dict, optional):
+            Keyword arguments for the `datasets.load_dataset` function.
+            (default: :obj:`{}`)
+        retrieval_kwargs (dict, optional):
+            Keyword arguments for the
+            `get_features_for_triplets_groups` function.
+            (default: :obj:`{}`)
     """
-    def __init__(
-
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        force_reload: bool = False,
+        verbose: bool = False,
+        use_pcst: bool = True,
+        load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+        retrieval_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        load_dataset_kwargs = load_dataset_kwargs or {}
+        retrieval_kwargs = retrieval_kwargs or {}
+        # Modify these paramters if running into memory/compute issues
+        default_retrieval_kwargs = {
+            'max_batch_size': 250,  # Lower batch size to reduce memory usage
+            'num_workers':
+            None,  # Use all available workers, or set to number of threads
+        }
+        retrieval_kwargs = {**default_retrieval_kwargs, **retrieval_kwargs}
         dataset_name = 'rmanluo/RoG-webqsp'
-        super().__init__(dataset_name, root, split, force_reload,
+        super().__init__(dataset_name, root, split, force_reload, verbose,
+                         use_pcst, load_dataset_kwargs=load_dataset_kwargs,
+                         retrieval_kwargs=retrieval_kwargs)
 
 
 class CWQDataset(KGQABaseDataset):
```
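Based on the updated `WebQSPDataset` constructor above, a usage sketch; the root path and keyword values are illustrative assumptions:

```python
from torch_geometric.datasets import WebQSPDataset

# The first instantiation downloads the HF dataset, builds and caches the KG,
# and processes all three splits; later calls just load the cached split.
train_dataset = WebQSPDataset(
    root="data/webqsp",
    split="train",            # "train", "val", or "test"
    use_pcst=True,            # prune each example's graph with PCST
    verbose=True,
    load_dataset_kwargs={},   # forwarded to datasets.load_dataset()
    retrieval_kwargs={"max_batch_size": 128},  # below the 250 default if memory is tight
)
```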
```diff
@@ -283,10 +315,30 @@ class CWQDataset(KGQABaseDataset):
             If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
         force_reload (bool, optional): Whether to re-process the dataset.
             (default: :obj:`False`)
+        verbose (bool, optional): Whether to print output. Defaults to False.
         use_pcst (bool, optional): Whether to preprocess the dataset's graph
             with PCST or return the full graphs. (default: :obj:`True`)
+        load_dataset_kwargs (dict, optional):
+            Keyword arguments for the `datasets.load_dataset` function.
+            (default: :obj:`{}`)
+        retrieval_kwargs (dict, optional):
+            Keyword arguments for the
+            `get_features_for_triplets_groups` function.
+            (default: :obj:`{}`)
     """
-    def __init__(
-
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        force_reload: bool = False,
+        verbose: bool = False,
+        use_pcst: bool = True,
+        load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+        retrieval_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        load_dataset_kwargs = load_dataset_kwargs or {}
+        retrieval_kwargs = retrieval_kwargs or {}
         dataset_name = 'rmanluo/RoG-cwq'
-        super().__init__(dataset_name, root, split, force_reload,
+        super().__init__(dataset_name, root, split, force_reload, verbose,
+                         use_pcst, load_dataset_kwargs=load_dataset_kwargs,
+                         retrieval_kwargs=retrieval_kwargs)
```
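`CWQDataset` accepts the same arguments but, unlike `WebQSPDataset`, does not inject default `retrieval_kwargs`, so memory and worker settings are passed explicitly if needed; the values below are illustrative assumptions:

```python
from torch_geometric.datasets import CWQDataset

val_dataset = CWQDataset(
    root="data/cwq",
    split="val",                # mapped to the 'validation' split internally
    retrieval_kwargs={
        "max_batch_size": 250,  # cap the retrieval batch size
        "num_workers": None,    # None uses all available workers
    },
)
```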
```diff
@@ -3,11 +3,11 @@ import warnings
 import torch
 
 try:
-    import
+    import lightning.pytorch as pl
     _pl_is_available = True
 except ImportError:
     try:
-        import
+        import pytorch_lightning as pl
         _pl_is_available = True
     except ImportError:
         _pl_is_available = False
```