pyg-nightly 2.7.0.dev20250905__py3-none-any.whl → 2.7.0.dev20250907__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250907.dist-info}/METADATA +2 -1
  2. {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250907.dist-info}/RECORD +32 -25
  3. torch_geometric/__init__.py +1 -1
  4. torch_geometric/data/__init__.py +0 -5
  5. torch_geometric/data/lightning/datamodule.py +2 -2
  6. torch_geometric/datasets/molecule_gpt_dataset.py +1 -1
  7. torch_geometric/datasets/web_qsp_dataset.py +262 -210
  8. torch_geometric/graphgym/imports.py +2 -2
  9. torch_geometric/llm/__init__.py +9 -0
  10. torch_geometric/{data → llm}/large_graph_indexer.py +124 -61
  11. torch_geometric/llm/models/__init__.py +23 -0
  12. torch_geometric/{nn → llm}/models/g_retriever.py +68 -49
  13. torch_geometric/{nn → llm}/models/git_mol.py +1 -1
  14. torch_geometric/{nn/nlp → llm/models}/llm.py +167 -33
  15. torch_geometric/llm/models/llm_judge.py +158 -0
  16. torch_geometric/{nn → llm}/models/molecule_gpt.py +1 -1
  17. torch_geometric/{nn/nlp → llm/models}/sentence_transformer.py +42 -8
  18. torch_geometric/llm/models/txt2kg.py +353 -0
  19. torch_geometric/llm/rag_loader.py +154 -0
  20. torch_geometric/llm/utils/backend_utils.py +442 -0
  21. torch_geometric/llm/utils/feature_store.py +169 -0
  22. torch_geometric/llm/utils/graph_store.py +199 -0
  23. torch_geometric/llm/utils/vectorrag.py +124 -0
  24. torch_geometric/loader/__init__.py +0 -4
  25. torch_geometric/nn/__init__.py +0 -1
  26. torch_geometric/nn/models/__init__.py +0 -10
  27. torch_geometric/nn/models/sgformer.py +2 -0
  28. torch_geometric/loader/rag_loader.py +0 -107
  29. torch_geometric/nn/nlp/__init__.py +0 -9
  30. {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250907.dist-info}/WHEEL +0 -0
  31. {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250907.dist-info}/licenses/LICENSE +0 -0
  32. /torch_geometric/{nn → llm}/models/glem.py +0 -0
  33. /torch_geometric/{nn → llm}/models/protein_mpnn.py +0 -0
  34. /torch_geometric/{nn/nlp → llm/models}/vision_transformer.py +0 -0
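The headline change in this nightly is the new top-level torch_geometric.llm package: the GNN+LLM models previously under torch_geometric.nn.models and torch_geometric.nn.nlp, the LargeGraphIndexer from torch_geometric.data, and the RAG loader from torch_geometric.loader all move there, alongside new modules (llm_judge.py, txt2kg.py, vectorrag.py, and backend/feature-store/graph-store utilities). Downstream imports need updating accordingly. A minimal sketch of the migration, assuming the old names were re-exported from the listed locations (the matching removals in the data, loader, nn.models, and nn.nlp __init__ files above support this):

    # Before (<= 2.7.0.dev20250905): LLM-related classes lived across
    # several subpackages (paths taken from the rename list above).
    from torch_geometric.data import LargeGraphIndexer
    from torch_geometric.loader import RAGQueryLoader
    from torch_geometric.nn.models import GRetriever
    from torch_geometric.nn.nlp import SentenceTransformer

    # After (>= 2.7.0.dev20250907): consolidated under torch_geometric.llm.
    from torch_geometric.llm import LargeGraphIndexer, RAGQueryLoader
    from torch_geometric.llm.models import GRetriever, SentenceTransformer

The largest rewrite is in torch_geometric/datasets/web_qsp_dataset.py, shown below.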
torch_geometric/datasets/web_qsp_dataset.py
@@ -1,120 +1,26 @@
  # Code adapted from the G-Retriever paper: https://arxiv.org/abs/2402.07630
- from typing import Any, Dict, List, Tuple, no_type_check
+ import gc
+ import os
+ from itertools import chain
+ from typing import Any, Dict, Iterator, List, Optional

- import numpy as np
  import torch
- from torch import Tensor
  from tqdm import tqdm

- from torch_geometric.data import Data, InMemoryDataset
- from torch_geometric.nn.nlp import SentenceTransformer
-
-
- @no_type_check
- def retrieval_via_pcst(
-     data: Data,
-     q_emb: Tensor,
-     textual_nodes: Any,
-     textual_edges: Any,
-     topk: int = 3,
-     topk_e: int = 3,
-     cost_e: float = 0.5,
- ) -> Tuple[Data, str]:
-     c = 0.01
-
-     from pcst_fast import pcst_fast
-
-     root = -1
-     num_clusters = 1
-     pruning = 'gw'
-     verbosity_level = 0
-     if topk > 0:
-         n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.x)
-         topk = min(topk, data.num_nodes)
-         _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
-
-         n_prizes = torch.zeros_like(n_prizes)
-         n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
-     else:
-         n_prizes = torch.zeros(data.num_nodes)
-
-     if topk_e > 0:
-         e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.edge_attr)
-         topk_e = min(topk_e, e_prizes.unique().size(0))
-
-         topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e, largest=True)
-         e_prizes[e_prizes < topk_e_values[-1]] = 0.0
-         last_topk_e_value = topk_e
-         for k in range(topk_e):
-             indices = e_prizes == topk_e_values[k]
-             value = min((topk_e - k) / sum(indices), last_topk_e_value - c)
-             e_prizes[indices] = value
-             last_topk_e_value = value * (1 - c)
-         # reduce the cost of the edges such that at least one edge is selected
-         cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
-     else:
-         e_prizes = torch.zeros(data.num_edges)
-
-     costs = []
-     edges = []
-     virtual_n_prizes = []
-     virtual_edges = []
-     virtual_costs = []
-     mapping_n = {}
-     mapping_e = {}
-     for i, (src, dst) in enumerate(data.edge_index.t().numpy()):
-         prize_e = e_prizes[i]
-         if prize_e <= cost_e:
-             mapping_e[len(edges)] = i
-             edges.append((src, dst))
-             costs.append(cost_e - prize_e)
-         else:
-             virtual_node_id = data.num_nodes + len(virtual_n_prizes)
-             mapping_n[virtual_node_id] = i
-             virtual_edges.append((src, virtual_node_id))
-             virtual_edges.append((virtual_node_id, dst))
-             virtual_costs.append(0)
-             virtual_costs.append(0)
-             virtual_n_prizes.append(prize_e - cost_e)
-
-     prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])
-     num_edges = len(edges)
-     if len(virtual_costs) > 0:
-         costs = np.array(costs + virtual_costs)
-         edges = np.array(edges + virtual_edges)
-
-     vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters,
-                                 pruning, verbosity_level)
-
-     selected_nodes = vertices[vertices < data.num_nodes]
-     selected_edges = [mapping_e[e] for e in edges if e < num_edges]
-     virtual_vertices = vertices[vertices >= data.num_nodes]
-     if len(virtual_vertices) > 0:
-         virtual_vertices = vertices[vertices >= data.num_nodes]
-         virtual_edges = [mapping_n[i] for i in virtual_vertices]
-         selected_edges = np.array(selected_edges + virtual_edges)
-
-     edge_index = data.edge_index[:, selected_edges]
-     selected_nodes = np.unique(
-         np.concatenate(
-             [selected_nodes, edge_index[0].numpy(), edge_index[1].numpy()]))
-
-     n = textual_nodes.iloc[selected_nodes]
-     e = textual_edges.iloc[selected_edges]
-     desc = n.to_csv(index=False) + '\n' + e.to_csv(
-         index=False, columns=['src', 'edge_attr', 'dst'])
-
-     mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
-     src = [mapping[i] for i in edge_index[0].tolist()]
-     dst = [mapping[i] for i in edge_index[1].tolist()]
-
-     data = Data(
-         x=data.x[selected_nodes],
-         edge_index=torch.tensor([src, dst]),
-         edge_attr=data.edge_attr[selected_edges],
-     )
-
-     return data, desc
+ from torch_geometric.data import InMemoryDataset
+ from torch_geometric.llm.large_graph_indexer import (
+     EDGE_RELATION,
+     LargeGraphIndexer,
+     TripletLike,
+     get_features_for_triplets_groups,
+ )
+ from torch_geometric.llm.models import SentenceTransformer
+ from torch_geometric.llm.utils.backend_utils import retrieval_via_pcst
+
+
+ def preprocess_triplet(triplet: TripletLike) -> TripletLike:
+     h, r, t = triplet
+     return str(h).lower(), str(r).lower(), str(t).lower()


  class KGQABaseDataset(InMemoryDataset):
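Two things change in this hunk: the PCST pruning helper retrieval_via_pcst is no longer defined inline but imported from its new home in torch_geometric.llm.utils.backend_utils, and raw triplets are normalized through the new preprocess_triplet hook before indexing. A small illustration of that hook as defined above (the triplet values are made up):

    # Lower-casing head, relation, and tail makes differently-cased mentions
    # of the same entity collapse to a single node/edge key in the indexer.
    triplet = ("Barack_Obama", "people.person.place_of_birth", "Honolulu")
    assert preprocess_triplet(triplet) == \
        ("barack_obama", "people.person.place_of_birth", "honolulu")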
@@ -130,8 +36,16 @@ class KGQABaseDataset(InMemoryDataset):
              If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
          force_reload (bool, optional): Whether to re-process the dataset.
              (default: :obj:`False`)
+         verbose (bool, optional): Whether to print output. Defaults to False.
          use_pcst (bool, optional): Whether to preprocess the dataset's graph
              with PCST or return the full graphs. (default: :obj:`True`)
+         load_dataset_kwargs (dict, optional):
+             Keyword arguments for the `datasets.load_dataset` function.
+             (default: :obj:`{}`)
+         retrieval_kwargs (dict, optional):
+             Keyword arguments for the
+             `get_features_for_triplets_groups` function.
+             (default: :obj:`{}`)
      """
      def __init__(
          self,
@@ -139,115 +53,206 @@ class KGQABaseDataset(InMemoryDataset):
          root: str,
          split: str = "train",
          force_reload: bool = False,
+         verbose: bool = False,
          use_pcst: bool = True,
-         use_cwq: bool = True,
+         load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+         retrieval_kwargs: Optional[Dict[str, Any]] = None,
      ) -> None:
+         self.split = split
          self.dataset_name = dataset_name
          self.use_pcst = use_pcst
+         self.load_dataset_kwargs = load_dataset_kwargs or {}
+         """
+         NOTE: If running into memory issues,
+         try reducing this batch size for the LargeGraphIndexer
+         used to build our KG.
+         Example: self.retrieval_kwargs = {"batch_size": 64}
+         """
+         self.retrieval_kwargs = retrieval_kwargs or {}
+
+         # Caching custom subsets of the dataset results in unsupported behavior
+         if 'split' in self.load_dataset_kwargs:
+             print("WARNING: Caching custom subsets of the dataset \
+                 results in unsupported behavior.\
+                 Please specify a separate root directory for each split,\
+                 or set force_reload=True on subsequent instantiations\
+                 of the dataset.")
+
+         self.required_splits = ['train', 'validation', 'test']
+
+         self.verbose = verbose
+         self.force_reload = force_reload
          super().__init__(root, force_reload=force_reload)
-
-         if split not in {'train', 'val', 'test'}:
+         """
+         NOTE: Current behavior is to process the entire dataset,
+         and only return the split specified by the user.
+         """
+         if f'{split}_data.pt' not in set(self.processed_file_names):
              raise ValueError(f"Invalid 'split' argument (got {split})")
+         if split == 'val':
+             split = 'validation'

-         path = self.processed_paths[['train', 'val', 'test'].index(split)]
-         self.load(path)
+         self.load(self.processed_paths[self.required_splits.index(split)])
+
+     @property
+     def raw_file_names(self) -> List[str]:
+         return ["raw.pt"]

      @property
      def processed_file_names(self) -> List[str]:
-         return ['train_data.pt', 'val_data.pt', 'test_data.pt']
+         return ["train_data.pt", "val_data.pt", "test_data.pt"]

-     def process(self) -> None:
+     def download(self) -> None:
          import datasets
-         import pandas as pd
-
-         datasets = datasets.load_dataset(self.dataset_name)

-         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-         model_name = 'sentence-transformers/all-roberta-large-v1'
-         model = SentenceTransformer(model_name).to(device)
-         model.eval()
-
-         for dataset, path in zip(
-             [datasets['train'], datasets['validation'], datasets['test']],
-             self.processed_paths,
-         ):
-             questions = [example["question"] for example in dataset]
-             question_embs = model.encode(
-                 questions,
-                 batch_size=256,
-                 output_device='cpu',
-             )
-
-             data_list = []
-             for i, example in enumerate(tqdm(dataset)):
-                 raw_nodes: Dict[str, int] = {}
-                 raw_edges = []
-                 for tri in example["graph"]:
-                     h, r, t = tri
-                     h = h.lower()
-                     t = t.lower()
-                     if h not in raw_nodes:
-                         raw_nodes[h] = len(raw_nodes)
-                     if t not in raw_nodes:
-                         raw_nodes[t] = len(raw_nodes)
-                     raw_edges.append({
-                         "src": raw_nodes[h],
-                         "edge_attr": r,
-                         "dst": raw_nodes[t]
-                     })
-                 nodes = pd.DataFrame([{
-                     "node_id": v,
-                     "node_attr": k,
-                 } for k, v in raw_nodes.items()],
-                                      columns=["node_id", "node_attr"])
-                 edges = pd.DataFrame(raw_edges,
-                                      columns=["src", "edge_attr", "dst"])
-
-                 nodes.node_attr = nodes.node_attr.fillna("")
-                 x = model.encode(
-                     nodes.node_attr.tolist(),
-                     batch_size=256,
-                     output_device='cpu',
-                 )
-                 edge_attr = model.encode(
-                     edges.edge_attr.tolist(),
-                     batch_size=256,
-                     output_device='cpu',
-                 )
-                 edge_index = torch.tensor([
-                     edges.src.tolist(),
-                     edges.dst.tolist(),
-                 ], dtype=torch.long)
-
-                 question = f"Question: {example['question']}\nAnswer: "
-                 label = ('|').join(example['answer']).lower()
-                 data = Data(
-                     x=x,
-                     edge_index=edge_index,
-                     edge_attr=edge_attr,
-                 )
-                 if self.use_pcst and len(nodes) > 0 and len(edges) > 0:
-                     data, desc = retrieval_via_pcst(
-                         data,
-                         question_embs[i],
-                         nodes,
-                         edges,
-                         topk=3,
-                         topk_e=5,
-                         cost_e=0.5,
+         # HF Load Dataset by dataset name if no path is specified
+         self.load_dataset_kwargs['path'] = self.load_dataset_kwargs.get(
+             'path', self.dataset_name)
+         raw_dataset = datasets.load_dataset(**self.load_dataset_kwargs)
+
+         # Assert that the dataset contains the required splits
+         assert all(split in raw_dataset for split in self.required_splits), \
+             f"Dataset '{self.dataset_name}' is missing required splits: \
+             {self.required_splits}"
+
+         raw_dataset.save_to_disk(self.raw_paths[0])
+
+     def _get_trips(self) -> Iterator[TripletLike]:
+         # Iterate over each element's graph in each split of the dataset
+         # Using chain to lazily iterate without storing all trips in memory
+         split_iterators = []
+
+         for split in self.required_splits:
+             # Create an iterator for each element's graph in the current split
+             split_graphs = (element['graph']
+                             for element in self.raw_dataset[split])
+             split_iterators.append(chain.from_iterable(split_graphs))
+
+         # Chain all split iterators together
+         return chain.from_iterable(split_iterators)
+
+     def _build_graph(self) -> None:
+         print("Encoding graph...")
+         trips = self._get_trips()
+         self.indexer: LargeGraphIndexer = LargeGraphIndexer.from_triplets(
+             trips, pre_transform=preprocess_triplet)
+
+         # Nodes:
+         print("\tEncoding nodes...")
+         nodes = self.indexer.get_unique_node_features()
+         x = self.model.encode(nodes, batch_size=256, output_device='cpu')
+         self.indexer.add_node_feature(new_feature_name="x", new_feature_vals=x)
+
+         # Edges:
+         print("\tEncoding edges...")
+         edges = self.indexer.get_unique_edge_features(
+             feature_name=EDGE_RELATION)
+         edge_attr = self.model.encode(edges, batch_size=256,
+                                       output_device='cpu')
+         self.indexer.add_edge_feature(
+             new_feature_name="edge_attr",
+             new_feature_vals=edge_attr,
+             map_from_feature=EDGE_RELATION,
+         )
+
+         print("\tSaving graph...")
+         self.indexer.save(self.indexer_path)
+
+     def _retrieve_subgraphs(self) -> None:
+         raw_splits = [
+             self.raw_dataset[split] for split in self.required_splits
+         ]
+         zipped = zip(
+             self.required_splits,
+             raw_splits,  # noqa
+             self.processed_paths,
+         )
+         for split_name, dataset, path in zipped:
+             print(f"Processing {split_name} split...")
+
+             print("\tEncoding questions...")
+             split_questions = [str(element['question']) for element in dataset]
+             split_q_embs = self.model.encode(split_questions, batch_size=256,
+                                              output_device='cpu')
+
+             print("\tRetrieving subgraphs...")
+             results_graphs = []
+             retrieval_kwargs = {
+                 **self.retrieval_kwargs,
+                 **{
+                     'pre_transform': preprocess_triplet,
+                     'verbose': self.verbose,
+                 }
+             }
+             graph_gen = get_features_for_triplets_groups(
+                 self.indexer, (element['graph'] for element in dataset),
+                 **retrieval_kwargs)
+
+             for index in tqdm(range(len(dataset)), disable=not self.verbose):
+                 data_i = dataset[index]
+                 graph = next(graph_gen)
+                 textual_nodes = self.textual_nodes.iloc[
+                     graph["node_idx"]].reset_index()
+                 textual_edges = self.textual_edges.iloc[
+                     graph["edge_idx"]].reset_index()
+                 if self.use_pcst and len(textual_nodes) > 0 and len(
+                         textual_edges) > 0:
+                     subgraph, desc = retrieval_via_pcst(
+                         graph,
+                         split_q_embs[index],
+                         textual_nodes,
+                         textual_edges,
                      )
                  else:
-                     desc = nodes.to_csv(index=False) + "\n" + edges.to_csv(
-                         index=False,
-                         columns=["src", "edge_attr", "dst"],
-                     )
+                     desc = textual_nodes.to_csv(
+                         index=False) + "\n" + textual_edges.to_csv(
+                             index=False,
+                             columns=["src", "edge_attr", "dst"],
+                         )
+                     subgraph = graph
+                 question = f"Question: {data_i['question']}\nAnswer: "
+                 label = ("|").join(data_i["answer"]).lower()
+
+                 subgraph["question"] = question
+                 subgraph["label"] = label
+                 subgraph["desc"] = desc
+                 results_graphs.append(subgraph.to("cpu"))
+             print("\tSaving subgraphs...")
+             self.save(results_graphs, path)

-                 data.question = question
-                 data.label = label
-                 data.desc = desc
-                 data_list.append(data)
+     def process(self) -> None:
+         import datasets
+         from pandas import DataFrame
+         self.raw_dataset = datasets.load_from_disk(self.raw_paths[0])

-             self.save(data_list, path)
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         model_name = 'sentence-transformers/all-roberta-large-v1'
+         self.model: SentenceTransformer = SentenceTransformer(model_name).to(
+             device)
+         self.model.eval()
+         self.indexer_path = os.path.join(self.processed_dir,
+                                          "large_graph_indexer")
+         if self.force_reload or not os.path.exists(self.indexer_path):
+             self._build_graph()
+         else:
+             print("Loading graph...")
+             self.indexer = LargeGraphIndexer.from_disk(self.indexer_path)
+         self.textual_nodes = DataFrame.from_dict(
+             {"node_attr": self.indexer.get_node_features()})
+         self.textual_nodes["node_id"] = self.textual_nodes.index
+         self.textual_nodes = self.textual_nodes[["node_id", "node_attr"]]
+         self.textual_edges = DataFrame(self.indexer.get_edge_features(),
+                                        columns=["src", "edge_attr", "dst"])
+         self.textual_edges["src"] = [
+             self.indexer._nodes[h] for h in self.textual_edges["src"]
+         ]
+         self.textual_edges["dst"] = [
+             self.indexer._nodes[h] for h in self.textual_edges["dst"]
+         ]
+         self._retrieve_subgraphs()
+
+         gc.collect()
+         torch.cuda.empty_cache()


  class WebQSPDataset(KGQABaseDataset):
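The monolithic process() is now a staged pipeline: download() snapshots the Hugging Face dataset to disk, _build_graph() folds the triplets from every split into one LargeGraphIndexer (built once and cached under processed/large_graph_indexer), and _retrieve_subgraphs() pulls a per-question subgraph back out. A minimal sketch of the indexer workflow in isolation, using made-up triplets and only calls that appear in this diff (pre_transform is optional here):

    from torch_geometric.llm.large_graph_indexer import (
        EDGE_RELATION,
        LargeGraphIndexer,
    )

    # Toy knowledge graph; the real triplets come from the HF dataset splits.
    triplets = [
        ("earth", "orbits", "sun"),
        ("moon", "orbits", "earth"),
    ]
    indexer = LargeGraphIndexer.from_triplets(triplets)

    # Deduplicated node/relation strings, each embedded exactly once:
    node_texts = indexer.get_unique_node_features()
    relation_texts = indexer.get_unique_edge_features(
        feature_name=EDGE_RELATION)

    # After encoding the strings (e.g. with SentenceTransformer.encode),
    # the embeddings are attached and the index is persisted for reuse:
    # indexer.add_node_feature(new_feature_name="x", new_feature_vals=x)
    # indexer.save("processed/large_graph_indexer")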
@@ -262,13 +267,40 @@ class WebQSPDataset(KGQABaseDataset):
              If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
          force_reload (bool, optional): Whether to re-process the dataset.
              (default: :obj:`False`)
+         verbose (bool, optional): Whether to print output. Defaults to False.
          use_pcst (bool, optional): Whether to preprocess the dataset's graph
              with PCST or return the full graphs. (default: :obj:`True`)
+         load_dataset_kwargs (dict, optional):
+             Keyword arguments for the `datasets.load_dataset` function.
+             (default: :obj:`{}`)
+         retrieval_kwargs (dict, optional):
+             Keyword arguments for the
+             `get_features_for_triplets_groups` function.
+             (default: :obj:`{}`)
      """
-     def __init__(self, root: str, split: str = "train",
-                  force_reload: bool = False, use_pcst: bool = True) -> None:
+     def __init__(
+         self,
+         root: str,
+         split: str = "train",
+         force_reload: bool = False,
+         verbose: bool = False,
+         use_pcst: bool = True,
+         load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+         retrieval_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> None:
+         load_dataset_kwargs = load_dataset_kwargs or {}
+         retrieval_kwargs = retrieval_kwargs or {}
+         # Modify these paramters if running into memory/compute issues
+         default_retrieval_kwargs = {
+             'max_batch_size': 250,  # Lower batch size to reduce memory usage
+             'num_workers':
+             None,  # Use all available workers, or set to number of threads
+         }
+         retrieval_kwargs = {**default_retrieval_kwargs, **retrieval_kwargs}
          dataset_name = 'rmanluo/RoG-webqsp'
-         super().__init__(dataset_name, root, split, force_reload, use_pcst)
+         super().__init__(dataset_name, root, split, force_reload, verbose,
+                          use_pcst, load_dataset_kwargs=load_dataset_kwargs,
+                          retrieval_kwargs=retrieval_kwargs)


  class CWQDataset(KGQABaseDataset):
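For users, the visible effect is a richer constructor. A hypothetical call showing the new knobs (the values are illustrative; both kwargs default to empty dicts):

    from torch_geometric.datasets import WebQSPDataset

    dataset = WebQSPDataset(
        root="data/webqsp",
        split="train",
        verbose=True,   # progress output from the retrieval stage
        use_pcst=True,  # prune each graph with PCST, as before
        retrieval_kwargs={
            "max_batch_size": 100,  # below the 250 default, to save memory
            "num_workers": 2,       # None uses all available workers
        },
    )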
@@ -283,10 +315,30 @@ class CWQDataset(KGQABaseDataset):
              If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
          force_reload (bool, optional): Whether to re-process the dataset.
              (default: :obj:`False`)
+         verbose (bool, optional): Whether to print output. Defaults to False.
          use_pcst (bool, optional): Whether to preprocess the dataset's graph
              with PCST or return the full graphs. (default: :obj:`True`)
+         load_dataset_kwargs (dict, optional):
+             Keyword arguments for the `datasets.load_dataset` function.
+             (default: :obj:`{}`)
+         retrieval_kwargs (dict, optional):
+             Keyword arguments for the
+             `get_features_for_triplets_groups` function.
+             (default: :obj:`{}`)
      """
-     def __init__(self, root: str, split: str = "train",
-                  force_reload: bool = False, use_pcst: bool = True) -> None:
+     def __init__(
+         self,
+         root: str,
+         split: str = "train",
+         force_reload: bool = False,
+         verbose: bool = False,
+         use_pcst: bool = True,
+         load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+         retrieval_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> None:
+         load_dataset_kwargs = load_dataset_kwargs or {}
+         retrieval_kwargs = retrieval_kwargs or {}
          dataset_name = 'rmanluo/RoG-cwq'
-         super().__init__(dataset_name, root, split, force_reload, use_pcst)
+         super().__init__(dataset_name, root, split, force_reload, verbose,
+                          use_pcst, load_dataset_kwargs=load_dataset_kwargs,
+                          retrieval_kwargs=retrieval_kwargs)
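CWQDataset gains the same expanded signature; unlike WebQSPDataset, it does not inject default max_batch_size/num_workers retrieval settings before delegating to the base class.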
torch_geometric/graphgym/imports.py
@@ -3,11 +3,11 @@ import warnings
  import torch

  try:
-     import pytorch_lightning as pl
+     import lightning.pytorch as pl
      _pl_is_available = True
  except ImportError:
      try:
-         import lightning.pytorch as pl
+         import pytorch_lightning as pl
          _pl_is_available = True
      except ImportError:
          _pl_is_available = False
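This hunk reverses the import preference: the maintained lightning.pytorch namespace is now tried first, with the legacy standalone pytorch_lightning distribution kept only as a fallback.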
torch_geometric/llm/__init__.py
@@ -0,0 +1,9 @@
+ from .large_graph_indexer import LargeGraphIndexer
+ from .rag_loader import RAGQueryLoader
+ from .utils import * # noqa
+ from .models import * # noqa
+
+ __all__ = classes = [
+     LargeGraphIndexer,
+     RAGQueryLoader,
+ ]