pyg-nightly 2.7.0.dev20250905__py3-none-any.whl → 2.7.0.dev20250906__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/METADATA +2 -1
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/RECORD +32 -25
- torch_geometric/__init__.py +1 -1
- torch_geometric/data/__init__.py +0 -5
- torch_geometric/data/lightning/datamodule.py +2 -2
- torch_geometric/datasets/molecule_gpt_dataset.py +1 -1
- torch_geometric/datasets/web_qsp_dataset.py +262 -210
- torch_geometric/graphgym/imports.py +2 -2
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/{data → llm}/large_graph_indexer.py +124 -61
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +68 -49
- torch_geometric/{nn → llm}/models/git_mol.py +1 -1
- torch_geometric/{nn/nlp → llm/models}/llm.py +167 -33
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/{nn → llm}/models/molecule_gpt.py +1 -1
- torch_geometric/{nn/nlp → llm/models}/sentence_transformer.py +42 -8
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/backend_utils.py +442 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +124 -0
- torch_geometric/loader/__init__.py +0 -4
- torch_geometric/nn/__init__.py +0 -1
- torch_geometric/nn/models/__init__.py +0 -10
- torch_geometric/nn/models/sgformer.py +2 -0
- torch_geometric/loader/rag_loader.py +0 -107
- torch_geometric/nn/nlp/__init__.py +0 -9
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/licenses/LICENSE +0 -0
- /torch_geometric/{nn → llm}/models/glem.py +0 -0
- /torch_geometric/{nn → llm}/models/protein_mpnn.py +0 -0
- /torch_geometric/{nn/nlp → llm/models}/vision_transformer.py +0 -0
torch_geometric/{data → llm}/large_graph_indexer.py:

```diff
@@ -2,7 +2,7 @@ import os
 import pickle as pkl
 import shutil
 from dataclasses import dataclass
-from itertools import chain
+from itertools import chain, islice, tee
 from typing import (
     Any,
     Callable,
@@ -37,15 +37,15 @@ def ordered_set(values: Iterable[str]) -> List[str]:
 
 # TODO: Refactor Node and Edge funcs and attrs to be accessible via an Enum?
 
-NODE_PID = "pid"
+NODE_PID = "pid"  # Encodes node id
 
 NODE_KEYS = {NODE_PID}
 
-EDGE_PID = "e_pid"
-EDGE_HEAD = "h"
-EDGE_RELATION = "r"
-EDGE_TAIL = "t"
-EDGE_INDEX = "edge_idx"
+EDGE_PID = "e_pid"  # Encodes source node, relation, destination node
+EDGE_HEAD = "h"  # Encodes source node
+EDGE_RELATION = "r"  # Encodes relation
+EDGE_TAIL = "t"  # Encodes destination node
+EDGE_INDEX = "edge_idx"  # Encodes source node, destination node
 
 EDGE_KEYS = {EDGE_PID, EDGE_HEAD, EDGE_RELATION, EDGE_TAIL, EDGE_INDEX}
 
@@ -88,6 +88,7 @@ class LargeGraphIndexer:
     Args:
         nodes (Iterable[str]): Node ids in the graph.
        edges (KnowledgeGraphLike): Edge ids in the graph.
+            Example: [("cats", "eat", "dogs")]
         node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node
             attribute name and list of their values in order of unique node
             ids. Defaults to None.
@@ -148,7 +149,6 @@ class LargeGraphIndexer:
             self.edge_attr[EDGE_TAIL].append(t)
             self.edge_attr[EDGE_INDEX].append(
                 (self._nodes[h], self._nodes[t]))
-
         for i, tup in enumerate(edges):
             self._edges[tup] = i
 
@@ -164,7 +164,8 @@ class LargeGraphIndexer:
 
         Args:
             triplets (KnowledgeGraphLike): Series of triplets representing
-                knowledge graph relations.
+                knowledge graph relations. Example: [("cats", "eat", "dogs")].
+                Note: Please ensure triplets are unique.
             pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
                 Optional preprocessing function to apply to triplets.
                 Defaults to None.
@@ -173,8 +174,8 @@ class LargeGraphIndexer:
             LargeGraphIndexer: Index of unique nodes and edges.
         """
         # NOTE: Right now assumes that all trips can be loaded into memory
-        nodes =
-        edges =
+        nodes = []
+        edges = []
 
         if pre_transform is not None:
 
@@ -183,16 +184,17 @@ class LargeGraphIndexer:
                 for trip in trips:
                     yield pre_transform(trip)
 
-            triplets = apply_transform(triplets)
+            triplets = list(apply_transform(triplets))
 
        for h, r, t in triplets:
 
            for node in (h, t):
-                nodes.
+                nodes.append(node)
 
            edge_idx = (h, r, t)
-            edges.
-
+            edges.append(edge_idx)
+        nodes = ordered_set(nodes)
+        edges = ordered_set(edges)
         return cls(list(nodes), list(edges))
 
     @classmethod
```
torch_geometric/{data → llm}/large_graph_indexer.py (continued):

```diff
@@ -291,13 +293,12 @@ class LargeGraphIndexer:
             values = self.node_attr[feature_name].values
         else:
             values = self.node_attr[feature_name]
-
         # TODO: torch_geometric.utils.select
         if isinstance(values, torch.Tensor):
             idxs = list(
                 self.get_node_features_iter(feature_name, pids,
                                             index_only=True))
-            return values[torch.tensor(idxs)]
+            return values[torch.tensor(idxs).long()]
         return list(self.get_node_features_iter(feature_name, pids))
 
     def get_node_features_iter(
@@ -421,7 +422,7 @@ class LargeGraphIndexer:
             idxs = list(
                 self.get_edge_features_iter(feature_name, pids,
                                             index_only=True))
-            return values[torch.tensor(idxs)]
+            return values[torch.tensor(idxs).long()]
         return list(self.get_edge_features_iter(feature_name, pids))
 
     def get_edge_features_iter(
@@ -532,7 +533,6 @@ class LargeGraphIndexer:
         """
         x = torch.Tensor(self.get_node_features(node_feature_name))
         node_id = torch.LongTensor(range(len(x)))
-
         edge_index = torch.t(
             torch.LongTensor(self.get_edge_features(EDGE_INDEX)))
 
@@ -572,8 +572,10 @@ def get_features_for_triplets_groups(
     triplet_groups: Iterable[KnowledgeGraphLike],
     node_feature_name: str = "x",
     edge_feature_name: str = "edge_attr",
-    pre_transform:
+    pre_transform: Callable[[TripletLike], TripletLike] = lambda trip: trip,
     verbose: bool = False,
+    max_batch_size: int = 250,
+    num_workers: Optional[int] = None,
 ) -> Iterator[Data]:
     """Given an indexer and a series of triplet groups (like a dataset),
     retrieve the specified node and edge features for each triplet from the
@@ -587,62 +589,123 @@ def get_features_for_triplets_groups(
             Defaults to "x".
         edge_feature_name (str, optional): edge feature to fetch.
             Defaults to "edge_attr".
-        pre_transform (
+        pre_transform (Callable[[TripletLike], TripletLike]):
             Optional preprocessing to perform on triplets.
             Defaults to None.
-        verbose (bool, optional): Whether to print progress.
+        verbose (bool, optional): Whether to print progress.
+            Defaults to False.
+        max_batch_size (int, optional):
+            Maximum batch size for fetching features.
+            Defaults to 250.
+        num_workers (int, optional):
+            Number of workers to use for fetching features.
+            Defaults to None (all available).
 
     Yields:
         Iterator[Data]: For each triplet group, yield a data object containing
         the unique graph and features from the index.
     """
-
+    def apply_transform(trips: Iterable[TripletLike]) -> Iterator[TripletLike]:
+        for trip in trips:
+            yield pre_transform(tuple(trip))
 
-
-
-
-
-
-        triplet_groups = (list(apply_transform(triplets))
-                          for triplets in triplet_groups)
+    # Carefully trying to avoid loading all triplets into memory at once
+    # While also still tracking the number of elements for tqdm
+    triplet_groups: List[Iterator[TripletLike]] = [
+        apply_transform(triplets) for triplets in triplet_groups
+    ]
 
     node_keys = []
     edge_keys = []
     edge_index = []
+    """
+    For each KG, we gather the node_indices, edge_keys,
+    and edge_indices needed to construct each Data object
+    """
 
-    for
+    for kg_triplets in tqdm(triplet_groups, disable=not verbose):
+        kg_triplets_nodes, kg_triplets_edge_keys, kg_triplets_edge_index = tee(
+            kg_triplets, 3)
+        """
+        Don't apply pre_transform here,
+        because it has already been applied on the triplet groups/
+        """
         small_graph_indexer = LargeGraphIndexer.from_triplets(
-
+            kg_triplets_nodes)
 
         node_keys.append(small_graph_indexer.get_node_features())
-        edge_keys.append(
+        edge_keys.append(
+            small_graph_indexer.get_edge_features(pids=kg_triplets_edge_keys))
         edge_index.append(
-            small_graph_indexer.get_edge_features(
-
-
-
-
-
-
-
-
-
-
-
-
-        edge_attr = torch.Tensor(edge_feats[last_edge_idx:last_edge_idx +
-                                            elen])
-        last_edge_idx += len(ekeys)
-
-        edge_idx = torch.LongTensor(eidx).T
-
-        data_obj = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)
-        data_obj[NODE_PID] = node_keys
-        data_obj[EDGE_PID] = edge_keys
-        data_obj["node_idx"] = [indexer._nodes[k] for k in nkeys]
-        data_obj["edge_idx"] = [indexer._edges[e] for e in ekeys]
+            small_graph_indexer.get_edge_features(
+                EDGE_INDEX,
+                kg_triplets_edge_index,
+            ))
+    """
+    We get the embeddings for each node and edge key in the KG,
+    but we need to do so in batches.
+    Batches that are too small waste compute time,
+    as each call to get features has an upfront cost.
+    Batches that are too large waste memory,
+    as we need to store all the result embeddings in memory.
+    """
 
-
+    def _fetch_feature_batch(batches):
+        node_key_batch, edge_key_batch, edge_index_batch = batches
+        node_feats = indexer.get_node_features(
+            feature_name=node_feature_name,
+            pids=chain.from_iterable(node_key_batch))
+        edge_feats = indexer.get_edge_features(
+            feature_name=edge_feature_name,
+            pids=chain.from_iterable(edge_key_batch))
+
+        last_node_idx, last_edge_idx = 0, 0
+        for (nkeys, ekeys, eidx) in zip(node_key_batch, edge_key_batch,
+                                        edge_index_batch):
+            nlen, elen = len(nkeys), len(ekeys)
+            x = torch.Tensor(node_feats[last_node_idx:last_node_idx + nlen])
+            last_node_idx += len(nkeys)
+
+            edge_attr = torch.Tensor(edge_feats[last_edge_idx:last_edge_idx +
+                                                elen])
+            last_edge_idx += len(ekeys)
+
+            edge_idx = torch.LongTensor(eidx).T
+
+            data_obj = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)
+            data_obj[NODE_PID] = node_keys
+            data_obj[EDGE_PID] = edge_keys
+            data_obj["node_idx"] = [indexer._nodes[k] for k in nkeys]
+            data_obj["edge_idx"] = [indexer._edges[e] for e in ekeys]
+
+            yield data_obj
+
+    # NOTE: Backport of itertools.batched from Python 3.12
+    def batched(iterable, n, *, strict=False):
+        # batched('ABCDEFG', 3) → ABC DEF G
+        if n < 1:
+            raise ValueError('n must be at least one')
+        iterator = iter(iterable)
+        while batch := tuple(islice(iterator, n)):
+            if strict and len(batch) != n:
+                raise ValueError('batched(): incomplete batch')
+            yield batch
+
+    import multiprocessing as mp
+    import multiprocessing.pool as mpp
+    num_workers = num_workers if num_workers is not None else mp.cpu_count()
+    ideal_batch_size = min(max_batch_size,
+                           max(1,
+                               len(triplet_groups) // num_workers))
+
+    node_key_batches = batched(node_keys, ideal_batch_size)
+    edge_key_batches = batched(edge_keys, ideal_batch_size)
+    edge_index_batches = batched(edge_index, ideal_batch_size)
+    batches = zip(node_key_batches, edge_key_batches, edge_index_batches)
+
+    with mpp.ThreadPool() as pool:
+        result = pool.map(_fetch_feature_batch, batches)
+        yield from chain.from_iterable(result)
 
 
 def get_features_for_triplets(
```
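The rewritten `get_features_for_triplets_groups` now chunks the per-group feature lookups with the `itertools.batched` backport shown above and fans the chunks out over a thread pool. As a standalone illustration of the chunking behaviour, here is a copy of that backported helper run on the example from its own comment:

```python
# Standalone copy of the backported helper above, for illustration only.
from itertools import islice


def batched(iterable, n, *, strict=False):
    # batched('ABCDEFG', 3) → ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    iterator = iter(iterable)
    while batch := tuple(islice(iterator, n)):
        if strict and len(batch) != n:
            raise ValueError('batched(): incomplete batch')
        yield batch


print(list(batched('ABCDEFG', 3)))
# [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
```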
torch_geometric/{data → llm}/large_graph_indexer.py (continued):

```diff
@@ -650,7 +713,7 @@ def get_features_for_triplets(
     triplets: KnowledgeGraphLike,
     node_feature_name: str = "x",
     edge_feature_name: str = "edge_attr",
-    pre_transform:
+    pre_transform: Callable[[TripletLike], TripletLike] = lambda trip: trip,
     verbose: bool = False,
 ) -> Data:
     """For a given set of triplets retrieve a Data object containing the
@@ -663,7 +726,7 @@ def get_features_for_triplets(
             Defaults to "x".
         edge_feature_name (str, optional): Feature to use for edge features.
             Defaults to "edge_attr".
-        pre_transform (
+        pre_transform (Callable[[TripletLike], TripletLike]):
             Optional preprocessing function for triplets. Defaults to None.
         verbose (bool, optional): Whether to print progress. Defaults to False.
 
@@ -674,5 +737,5 @@ def get_features_for_triplets(
     gen = get_features_for_triplets_groups(indexer, [triplets],
                                            node_feature_name,
                                            edge_feature_name, pre_transform,
-                                           verbose)
+                                           verbose, max_batch_size=1)
     return next(gen)
```
torch_geometric/llm/models/__init__.py (new file):

```diff
@@ -0,0 +1,23 @@
+from .sentence_transformer import SentenceTransformer
+from .vision_transformer import VisionTransformer
+from .llm import LLM
+from .txt2kg import TXT2KG
+from .llm_judge import LLMJudge
+from .g_retriever import GRetriever
+from .molecule_gpt import MoleculeGPT
+from .glem import GLEM
+from .protein_mpnn import ProteinMPNN
+from .git_mol import GITMol
+
+__all__ = [
+    'SentenceTransformer',
+    'VisionTransformer',
+    'LLM',
+    'LLMJudge',
+    'TXT2KG',
+    'GRetriever',
+    'MoleculeGPT',
+    'GLEM',
+    'ProteinMPNN',
+    'GITMol',
+]
```
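The new sub-package gathers the GNN+LLM models that previously lived under `torch_geometric.nn.models` and `torch_geometric.nn.nlp`. A hedged import sketch; only the import surface is shown here, since constructor arguments are not part of this diff:

```python
# Assumes pyg-nightly 2.7.0.dev20250906 is installed.
from torch_geometric.llm.models import LLM, GRetriever, SentenceTransformer, TXT2KG
```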
torch_geometric/{nn → llm}/models/g_retriever.py:

```diff
@@ -3,7 +3,7 @@ from typing import List, Optional
 import torch
 from torch import Tensor
 
-from torch_geometric.
+from torch_geometric.llm.models.llm import LLM, MAX_NEW_TOKENS
 from torch_geometric.utils import scatter
 
 
@@ -19,8 +19,6 @@ class GRetriever(torch.nn.Module):
             :obj:`peft` for training the LLM, see
             `here <https://huggingface.co/docs/peft/en/index>`_ for details.
             (default: :obj:`False`)
-        mlp_out_channels (int, optional): The size of each graph embedding
-            after projection. (default: :obj:`4096`)
         mlp_out_tokens (int, optional): Number of LLM prefix tokens to
             reserve for GNN output. (default: :obj:`1`)
 
@@ -42,15 +40,14 @@ class GRetriever(torch.nn.Module):
     def __init__(
         self,
         llm: LLM,
-        gnn: torch.nn.Module,
+        gnn: torch.nn.Module = None,
         use_lora: bool = False,
-        mlp_out_channels: int = 4096,
         mlp_out_tokens: int = 1,
     ) -> None:
         super().__init__()
 
         self.llm = llm
-        self.gnn = gnn.to(self.llm.device)
+        self.gnn = gnn.to(self.llm.device) if gnn is not None else None
 
         self.word_embedding = self.llm.word_embedding
         self.llm_generator = self.llm.llm
@@ -76,14 +73,18 @@ class GRetriever(torch.nn.Module):
             )
             self.llm_generator = get_peft_model(self.llm_generator, config)
 
-
-
-
-            torch.nn.
-
-
-
-
+        if self.gnn is not None:
+            mlp_out_channels = llm.word_embedding.embedding_dim
+            mlp_hidden_channels = self.gnn.out_channels
+            self.projector = torch.nn.Sequential(
+                torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
+                torch.nn.Sigmoid(),
+                torch.nn.Linear(mlp_hidden_channels,
+                                mlp_out_channels * mlp_out_tokens),
+                torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)),
+            ).to(self.llm.device)
+
+        self.seq_length_stats = []
 
     def encode(
         self,
```
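The projector is now built only when a GNN is supplied, and its output width is taken from the LLM itself (`word_embedding.embedding_dim`) instead of the removed `mlp_out_channels` argument. A self-contained sketch of the shape logic, with made-up dimensions standing in for `gnn.out_channels` and the LLM embedding size:

```python
# Dimensions are illustrative; the real values come from gnn.out_channels and
# llm.word_embedding.embedding_dim as in the diff above.
import torch

hidden, embed_dim, out_tokens = 1024, 4096, 2
projector = torch.nn.Sequential(
    torch.nn.Linear(hidden, hidden),
    torch.nn.Sigmoid(),
    torch.nn.Linear(hidden, embed_dim * out_tokens),
    torch.nn.Unflatten(-1, (out_tokens, embed_dim)),
)
graph_emb = torch.randn(8, hidden)  # one pooled GNN embedding per graph
print(projector(graph_emb).shape)   # torch.Size([8, 2, 4096])
```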
torch_geometric/{nn → llm}/models/g_retriever.py (continued):

```diff
@@ -98,7 +99,16 @@ class GRetriever(torch.nn.Module):
         edge_attr = edge_attr.to(self.llm.device)
         batch = batch.to(self.llm.device)
 
-
+        model_specific_kwargs = {}
+
+        # duck typing for SGFormer to get around circular import
+        if (hasattr(self.gnn, 'trans_conv')
+                and hasattr(self.gnn, 'graph_conv')):
+            model_specific_kwargs['batch'] = batch
+        else:
+            model_specific_kwargs['edge_attr'] = edge_attr
+
+        out = self.gnn(x, edge_index, **model_specific_kwargs)
         return scatter(out, batch, dim=0, reduce='mean')
 
     def forward(
```
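`encode` now routes keyword arguments by duck typing: SGFormer-style GNNs (which expose `trans_conv` and `graph_conv`) receive `batch`, while every other GNN keeps receiving `edge_attr`. A toy illustration of the dispatch, using a hypothetical stand-in module that is not part of the diff:

```python
# `PlainGNN` is a hypothetical stand-in, not part of the diff.
import torch


class PlainGNN(torch.nn.Module):
    def forward(self, x, edge_index, edge_attr=None):
        return x


gnn = PlainGNN()
kwargs = {}
if hasattr(gnn, 'trans_conv') and hasattr(gnn, 'graph_conv'):
    kwargs['batch'] = torch.zeros(4, dtype=torch.long)  # SGFormer-style path
else:
    kwargs['edge_attr'] = None  # generic GNN path (taken here)
out = gnn(torch.randn(4, 16), torch.empty(2, 0, dtype=torch.long), **kwargs)
print(out.shape)  # torch.Size([4, 16])
```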
torch_geometric/{nn → llm}/models/g_retriever.py (continued):

```diff
@@ -127,27 +137,32 @@ class GRetriever(torch.nn.Module):
             to give to the LLM, such as textified knowledge graphs.
             (default: :obj:`None`)
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        xs = None
+        if self.gnn is not None:
+            x = self.encode(x, edge_index, batch, edge_attr)
+            x = self.projector(x)
+            xs = x.split(1, dim=0)
+
+            # Handle case where theres more than one embedding for each sample
+            xs = [x.squeeze(0) for x in xs]
+
+            # Handle questions without node features:
+            batch_unique = batch.unique()
+            batch_size = len(question)
+            if len(batch_unique) < batch_size:
+                xs = [
+                    xs[i] if i in batch_unique else None
+                    for i in range(batch_size)
+                ]
         (
             inputs_embeds,
             attention_mask,
             label_input_ids,
         ) = self.llm._get_embeds(question, additional_text_context, xs, label)
 
+        max_seq_len = inputs_embeds.size(1)
+        self.seq_length_stats.append(max_seq_len)
+
         with self.llm.autocast_context:
             outputs = self.llm_generator(
                 inputs_embeds=inputs_embeds,
@@ -186,35 +201,39 @@ class GRetriever(torch.nn.Module):
         max_out_tokens (int, optional): How many tokens for the LLM to
             generate. (default: :obj:`32`)
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        xs = None
+        if self.gnn is not None:
+            x = self.encode(x, edge_index, batch, edge_attr)
+            x = self.projector(x)
+            xs = x.split(1, dim=0)
+
+            # Handle case where theres more than one embedding for each sample
+            xs = [x.squeeze(0) for x in xs]
+
+            # Handle questions without node features:
+            batch_unique = batch.unique()
+            batch_size = len(question)
+            if len(batch_unique) < batch_size:
+                xs = [
+                    xs[i] if i in batch_unique else None
+                    for i in range(batch_size)
+                ]
 
         inputs_embeds, attention_mask, _ = self.llm._get_embeds(
             question, additional_text_context, xs)
 
-        bos_token = self.llm.tokenizer(
-
-
-        ).input_ids[0]
+        # bos_token = self.llm.tokenizer(
+        #     self.llm.tokenizer.bos_token_id,
+        #     add_special_tokens=False,
+        # ).input_ids[0]
 
         with self.llm.autocast_context:
             outputs = self.llm_generator.generate(
                 inputs_embeds=inputs_embeds,
                 max_new_tokens=max_out_tokens,
                 attention_mask=attention_mask,
-                bos_token_id=
+                bos_token_id=self.llm.tokenizer.bos_token_id,
+                pad_token_id=self.llm.tokenizer.eos_token_id,
                 use_cache=True  # Important to set!
             )
 
```
torch_geometric/{nn → llm}/models/git_mol.py:

```diff
@@ -5,8 +5,8 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import BatchNorm1d, LayerNorm, Linear, ReLU, Sequential
 
+from torch_geometric.llm.models import SentenceTransformer, VisionTransformer
 from torch_geometric.nn import GINEConv
-from torch_geometric.nn.nlp import SentenceTransformer, VisionTransformer
 from torch_geometric.utils import add_self_loops, to_dense_batch
 
 
```