pyg-nightly 2.7.0.dev20250905__py3-none-any.whl → 2.7.0.dev20250906__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34):
  1. {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/METADATA +2 -1
  2. {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/RECORD +32 -25
  3. torch_geometric/__init__.py +1 -1
  4. torch_geometric/data/__init__.py +0 -5
  5. torch_geometric/data/lightning/datamodule.py +2 -2
  6. torch_geometric/datasets/molecule_gpt_dataset.py +1 -1
  7. torch_geometric/datasets/web_qsp_dataset.py +262 -210
  8. torch_geometric/graphgym/imports.py +2 -2
  9. torch_geometric/llm/__init__.py +9 -0
  10. torch_geometric/{data → llm}/large_graph_indexer.py +124 -61
  11. torch_geometric/llm/models/__init__.py +23 -0
  12. torch_geometric/{nn → llm}/models/g_retriever.py +68 -49
  13. torch_geometric/{nn → llm}/models/git_mol.py +1 -1
  14. torch_geometric/{nn/nlp → llm/models}/llm.py +167 -33
  15. torch_geometric/llm/models/llm_judge.py +158 -0
  16. torch_geometric/{nn → llm}/models/molecule_gpt.py +1 -1
  17. torch_geometric/{nn/nlp → llm/models}/sentence_transformer.py +42 -8
  18. torch_geometric/llm/models/txt2kg.py +353 -0
  19. torch_geometric/llm/rag_loader.py +154 -0
  20. torch_geometric/llm/utils/backend_utils.py +442 -0
  21. torch_geometric/llm/utils/feature_store.py +169 -0
  22. torch_geometric/llm/utils/graph_store.py +199 -0
  23. torch_geometric/llm/utils/vectorrag.py +124 -0
  24. torch_geometric/loader/__init__.py +0 -4
  25. torch_geometric/nn/__init__.py +0 -1
  26. torch_geometric/nn/models/__init__.py +0 -10
  27. torch_geometric/nn/models/sgformer.py +2 -0
  28. torch_geometric/loader/rag_loader.py +0 -107
  29. torch_geometric/nn/nlp/__init__.py +0 -9
  30. {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/WHEEL +0 -0
  31. {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/licenses/LICENSE +0 -0
  32. /torch_geometric/{nn → llm}/models/glem.py +0 -0
  33. /torch_geometric/{nn → llm}/models/protein_mpnn.py +0 -0
  34. /torch_geometric/{nn/nlp → llm/models}/vision_transformer.py +0 -0
@@ -2,7 +2,7 @@ import os
2
2
  import pickle as pkl
3
3
  import shutil
4
4
  from dataclasses import dataclass
5
- from itertools import chain
5
+ from itertools import chain, islice, tee
6
6
  from typing import (
7
7
  Any,
8
8
  Callable,
@@ -37,15 +37,15 @@ def ordered_set(values: Iterable[str]) -> List[str]:
37
37
 
38
38
  # TODO: Refactor Node and Edge funcs and attrs to be accessible via an Enum?
39
39
 
40
- NODE_PID = "pid"
40
+ NODE_PID = "pid" # Encodes node id
41
41
 
42
42
  NODE_KEYS = {NODE_PID}
43
43
 
44
- EDGE_PID = "e_pid"
45
- EDGE_HEAD = "h"
46
- EDGE_RELATION = "r"
47
- EDGE_TAIL = "t"
48
- EDGE_INDEX = "edge_idx"
44
+ EDGE_PID = "e_pid" # Encodes source node, relation, destination node
45
+ EDGE_HEAD = "h" # Encodes source node
46
+ EDGE_RELATION = "r" # Encodes relation
47
+ EDGE_TAIL = "t" # Encodes destination node
48
+ EDGE_INDEX = "edge_idx" # Encodes source node, destination node
49
49
 
50
50
  EDGE_KEYS = {EDGE_PID, EDGE_HEAD, EDGE_RELATION, EDGE_TAIL, EDGE_INDEX}
51
51
 
@@ -88,6 +88,7 @@ class LargeGraphIndexer:
88
88
  Args:
89
89
  nodes (Iterable[str]): Node ids in the graph.
90
90
  edges (KnowledgeGraphLike): Edge ids in the graph.
91
+ Example: [("cats", "eat", "dogs")]
91
92
  node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node
92
93
  attribute name and list of their values in order of unique node
93
94
  ids. Defaults to None.
@@ -148,7 +149,6 @@ class LargeGraphIndexer:
148
149
  self.edge_attr[EDGE_TAIL].append(t)
149
150
  self.edge_attr[EDGE_INDEX].append(
150
151
  (self._nodes[h], self._nodes[t]))
151
-
152
152
  for i, tup in enumerate(edges):
153
153
  self._edges[tup] = i
154
154
 
@@ -164,7 +164,8 @@ class LargeGraphIndexer:
164
164
 
165
165
  Args:
166
166
  triplets (KnowledgeGraphLike): Series of triplets representing
167
- knowledge graph relations.
167
+ knowledge graph relations. Example: [("cats", "eat", "dogs")].
168
+ Note: Please ensure triplets are unique.
168
169
  pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
169
170
  Optional preprocessing function to apply to triplets.
170
171
  Defaults to None.
@@ -173,8 +174,8 @@ class LargeGraphIndexer:
173
174
  LargeGraphIndexer: Index of unique nodes and edges.
174
175
  """
175
176
  # NOTE: Right now assumes that all trips can be loaded into memory
176
- nodes = set()
177
- edges = set()
177
+ nodes = []
178
+ edges = []
178
179
 
179
180
  if pre_transform is not None:
180
181
 
@@ -183,16 +184,17 @@ class LargeGraphIndexer:
183
184
  for trip in trips:
184
185
  yield pre_transform(trip)
185
186
 
186
- triplets = apply_transform(triplets)
187
+ triplets = list(apply_transform(triplets))
187
188
 
188
189
  for h, r, t in triplets:
189
190
 
190
191
  for node in (h, t):
191
- nodes.add(node)
192
+ nodes.append(node)
192
193
 
193
194
  edge_idx = (h, r, t)
194
- edges.add(edge_idx)
195
-
195
+ edges.append(edge_idx)
196
+ nodes = ordered_set(nodes)
197
+ edges = ordered_set(edges)
196
198
  return cls(list(nodes), list(edges))
197
199
 
198
200
  @classmethod
@@ -291,13 +293,12 @@ class LargeGraphIndexer:
291
293
  values = self.node_attr[feature_name].values
292
294
  else:
293
295
  values = self.node_attr[feature_name]
294
-
295
296
  # TODO: torch_geometric.utils.select
296
297
  if isinstance(values, torch.Tensor):
297
298
  idxs = list(
298
299
  self.get_node_features_iter(feature_name, pids,
299
300
  index_only=True))
300
- return values[torch.tensor(idxs)]
301
+ return values[torch.tensor(idxs).long()]
301
302
  return list(self.get_node_features_iter(feature_name, pids))
302
303
 
303
304
  def get_node_features_iter(
@@ -421,7 +422,7 @@ class LargeGraphIndexer:
421
422
  idxs = list(
422
423
  self.get_edge_features_iter(feature_name, pids,
423
424
  index_only=True))
424
- return values[torch.tensor(idxs)]
425
+ return values[torch.tensor(idxs).long()]
425
426
  return list(self.get_edge_features_iter(feature_name, pids))
426
427
 
427
428
  def get_edge_features_iter(
@@ -532,7 +533,6 @@ class LargeGraphIndexer:
532
533
  """
533
534
  x = torch.Tensor(self.get_node_features(node_feature_name))
534
535
  node_id = torch.LongTensor(range(len(x)))
535
-
536
536
  edge_index = torch.t(
537
537
  torch.LongTensor(self.get_edge_features(EDGE_INDEX)))
538
538
 
@@ -572,8 +572,10 @@ def get_features_for_triplets_groups(
572
572
  triplet_groups: Iterable[KnowledgeGraphLike],
573
573
  node_feature_name: str = "x",
574
574
  edge_feature_name: str = "edge_attr",
575
- pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
575
+ pre_transform: Callable[[TripletLike], TripletLike] = lambda trip: trip,
576
576
  verbose: bool = False,
577
+ max_batch_size: int = 250,
578
+ num_workers: Optional[int] = None,
577
579
  ) -> Iterator[Data]:
578
580
  """Given an indexer and a series of triplet groups (like a dataset),
579
581
  retrieve the specified node and edge features for each triplet from the
@@ -587,62 +589,123 @@ def get_features_for_triplets_groups(
587
589
  Defaults to "x".
588
590
  edge_feature_name (str, optional): edge feature to fetch.
589
591
  Defaults to "edge_attr".
590
- pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
592
+ pre_transform (Callable[[TripletLike], TripletLike]):
591
593
  Optional preprocessing to perform on triplets.
592
594
  Defaults to None.
593
- verbose (bool, optional): Whether to print progress. Defaults to False.
595
+ verbose (bool, optional): Whether to print progress.
596
+ Defaults to False.
597
+ max_batch_size (int, optional):
598
+ Maximum batch size for fetching features.
599
+ Defaults to 250.
600
+ num_workers (int, optional):
601
+ Number of workers to use for fetching features.
602
+ Defaults to None (all available).
594
603
 
595
604
  Yields:
596
605
  Iterator[Data]: For each triplet group, yield a data object containing
597
606
  the unique graph and features from the index.
598
607
  """
599
- if pre_transform is not None:
608
+ def apply_transform(trips: Iterable[TripletLike]) -> Iterator[TripletLike]:
609
+ for trip in trips:
610
+ yield pre_transform(tuple(trip))
600
611
 
601
- def apply_transform(trips):
602
- for trip in trips:
603
- yield pre_transform(tuple(trip))
604
-
605
- # TODO: Make this safe for large amounts of triplets?
606
- triplet_groups = (list(apply_transform(triplets))
607
- for triplets in triplet_groups)
612
+ # Carefully trying to avoid loading all triplets into memory at once
613
+ # While also still tracking the number of elements for tqdm
614
+ triplet_groups: List[Iterator[TripletLike]] = [
615
+ apply_transform(triplets) for triplets in triplet_groups
616
+ ]
608
617
 
609
618
  node_keys = []
610
619
  edge_keys = []
611
620
  edge_index = []
621
+ """
622
+ For each KG, we gather the node_indices, edge_keys,
623
+ and edge_indices needed to construct each Data object
624
+ """
612
625
 
613
- for triplets in tqdm(triplet_groups, disable=not verbose):
626
+ for kg_triplets in tqdm(triplet_groups, disable=not verbose):
627
+ kg_triplets_nodes, kg_triplets_edge_keys, kg_triplets_edge_index = tee(
628
+ kg_triplets, 3)
629
+ """
630
+ Don't apply pre_transform here,
631
+ because it has already been applied on the triplet groups/
632
+ """
614
633
  small_graph_indexer = LargeGraphIndexer.from_triplets(
615
- triplets, pre_transform=pre_transform)
634
+ kg_triplets_nodes)
616
635
 
617
636
  node_keys.append(small_graph_indexer.get_node_features())
618
- edge_keys.append(small_graph_indexer.get_edge_features(pids=triplets))
637
+ edge_keys.append(
638
+ small_graph_indexer.get_edge_features(pids=kg_triplets_edge_keys))
619
639
  edge_index.append(
620
- small_graph_indexer.get_edge_features(EDGE_INDEX, triplets))
621
-
622
- node_feats = indexer.get_node_features(feature_name=node_feature_name,
623
- pids=chain.from_iterable(node_keys))
624
- edge_feats = indexer.get_edge_features(feature_name=edge_feature_name,
625
- pids=chain.from_iterable(edge_keys))
626
-
627
- last_node_idx, last_edge_idx = 0, 0
628
- for (nkeys, ekeys, eidx) in zip(node_keys, edge_keys, edge_index):
629
- nlen, elen = len(nkeys), len(ekeys)
630
- x = torch.Tensor(node_feats[last_node_idx:last_node_idx + nlen])
631
- last_node_idx += len(nkeys)
632
-
633
- edge_attr = torch.Tensor(edge_feats[last_edge_idx:last_edge_idx +
634
- elen])
635
- last_edge_idx += len(ekeys)
636
-
637
- edge_idx = torch.LongTensor(eidx).T
638
-
639
- data_obj = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)
640
- data_obj[NODE_PID] = node_keys
641
- data_obj[EDGE_PID] = edge_keys
642
- data_obj["node_idx"] = [indexer._nodes[k] for k in nkeys]
643
- data_obj["edge_idx"] = [indexer._edges[e] for e in ekeys]
640
+ small_graph_indexer.get_edge_features(
641
+ EDGE_INDEX,
642
+ kg_triplets_edge_index,
643
+ ))
644
+ """
645
+ We get the embeddings for each node and edge key in the KG,
646
+ but we need to do so in batches.
647
+ Batches that are too small waste compute time,
648
+ as each call to get features has an upfront cost.
649
+ Batches that are too large waste memory,
650
+ as we need to store all the result embeddings in memory.
651
+ """
644
652
 
645
- yield data_obj
653
def _fetch_feature_batch(batches):
    """Build one ``Data`` object per triplet group in a batch.

    Node and edge features for the whole batch are looked up in
    ``indexer`` with a single call each (amortizing the per-call cost),
    then sliced back into per-group tensors.

    Args:
        batches: 3-tuple ``(node_key_batch, edge_key_batch,
            edge_index_batch)`` whose sequences are zipped together, one
            entry per triplet group.

    Yields:
        Data: The subgraph of each group, carrying ``x``, ``edge_attr``,
        ``edge_index``, the group's own node/edge keys, and the positions
        of those keys in ``indexer`` (``node_idx`` / ``edge_idx``).
    """
    # NOTE(review): this is a generator function, so ``pool.map`` at the
    # call site only creates generator objects; the fetching below runs
    # lazily in the consuming thread, not in the pool's worker threads.
    node_key_batch, edge_key_batch, edge_index_batch = batches
    node_feats = indexer.get_node_features(
        feature_name=node_feature_name,
        pids=chain.from_iterable(node_key_batch))
    edge_feats = indexer.get_edge_features(
        feature_name=edge_feature_name,
        pids=chain.from_iterable(edge_key_batch))

    # Running offsets into the flat, batch-wide feature sequences.
    last_node_idx, last_edge_idx = 0, 0
    for nkeys, ekeys, eidx in zip(node_key_batch, edge_key_batch,
                                  edge_index_batch):
        nlen, elen = len(nkeys), len(ekeys)
        x = torch.Tensor(node_feats[last_node_idx:last_node_idx + nlen])
        last_node_idx += nlen

        edge_attr = torch.Tensor(edge_feats[last_edge_idx:last_edge_idx +
                                            elen])
        last_edge_idx += elen

        # Each stored entry is a (head, tail) node-index pair; transpose
        # the E x 2 list into the 2 x E edge_index layout.
        edge_idx = torch.LongTensor(eidx).T

        data_obj = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)
        # Attach only this group's keys (aligned with node_idx/edge_idx),
        # not the accumulated key lists of every group.
        data_obj[NODE_PID] = nkeys
        data_obj[EDGE_PID] = ekeys
        data_obj["node_idx"] = [indexer._nodes[k] for k in nkeys]
        data_obj["edge_idx"] = [indexer._edges[e] for e in ekeys]

        yield data_obj
682
+
683
# NOTE: Backport of itertools.batched from Python 3.12
def batched(iterable, n, *, strict=False):
    """Lazily yield consecutive tuples of at most ``n`` items.

    Mirrors :func:`itertools.batched` (Python 3.12+), e.g.
    ``batched('ABCDEFG', 3)`` yields ``ABC``, ``DEF``, ``G``.

    Args:
        iterable: Source of items to group.
        n: Maximum size of each batch; must be at least one.
        strict: If set, a final batch shorter than ``n`` is an error.

    Raises:
        ValueError: When ``n < 1`` or, with ``strict``, when the last
            batch is incomplete (raised on iteration, as a generator).
    """
    if n < 1:
        raise ValueError('n must be at least one')
    source = iter(iterable)
    while True:
        chunk = tuple(islice(source, n))
        if not chunk:
            return
        if strict and len(chunk) != n:
            raise ValueError('batched(): incomplete batch')
        yield chunk
693
+
694
+ import multiprocessing as mp
695
+ import multiprocessing.pool as mpp
696
+ num_workers = num_workers if num_workers is not None else mp.cpu_count()
697
+ ideal_batch_size = min(max_batch_size,
698
+ max(1,
699
+ len(triplet_groups) // num_workers))
700
+
701
+ node_key_batches = batched(node_keys, ideal_batch_size)
702
+ edge_key_batches = batched(edge_keys, ideal_batch_size)
703
+ edge_index_batches = batched(edge_index, ideal_batch_size)
704
+ batches = zip(node_key_batches, edge_key_batches, edge_index_batches)
705
+
706
+ with mpp.ThreadPool() as pool:
707
+ result = pool.map(_fetch_feature_batch, batches)
708
+ yield from chain.from_iterable(result)
646
709
 
647
710
 
648
711
  def get_features_for_triplets(
@@ -650,7 +713,7 @@ def get_features_for_triplets(
650
713
  triplets: KnowledgeGraphLike,
651
714
  node_feature_name: str = "x",
652
715
  edge_feature_name: str = "edge_attr",
653
- pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
716
+ pre_transform: Callable[[TripletLike], TripletLike] = lambda trip: trip,
654
717
  verbose: bool = False,
655
718
  ) -> Data:
656
719
  """For a given set of triplets retrieve a Data object containing the
@@ -663,7 +726,7 @@ def get_features_for_triplets(
663
726
  Defaults to "x".
664
727
  edge_feature_name (str, optional): Feature to use for edge features.
665
728
  Defaults to "edge_attr".
666
- pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
729
+ pre_transform (Callable[[TripletLike], TripletLike]):
667
730
  Optional preprocessing function for triplets. Defaults to None.
668
731
  verbose (bool, optional): Whether to print progress. Defaults to False.
669
732
 
@@ -674,5 +737,5 @@ def get_features_for_triplets(
674
737
  gen = get_features_for_triplets_groups(indexer, [triplets],
675
738
  node_feature_name,
676
739
  edge_feature_name, pre_transform,
677
- verbose)
740
+ verbose, max_batch_size=1)
678
741
  return next(gen)
@@ -0,0 +1,23 @@
1
+ from .sentence_transformer import SentenceTransformer
2
+ from .vision_transformer import VisionTransformer
3
+ from .llm import LLM
4
+ from .txt2kg import TXT2KG
5
+ from .llm_judge import LLMJudge
6
+ from .g_retriever import GRetriever
7
+ from .molecule_gpt import MoleculeGPT
8
+ from .glem import GLEM
9
+ from .protein_mpnn import ProteinMPNN
10
+ from .git_mol import GITMol
11
+
12
+ __all__ = [
13
+ 'SentenceTransformer',
14
+ 'VisionTransformer',
15
+ 'LLM',
16
+ 'LLMJudge',
17
+ 'TXT2KG',
18
+ 'GRetriever',
19
+ 'MoleculeGPT',
20
+ 'GLEM',
21
+ 'ProteinMPNN',
22
+ 'GITMol',
23
+ ]
@@ -3,7 +3,7 @@ from typing import List, Optional
3
3
  import torch
4
4
  from torch import Tensor
5
5
 
6
- from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS
6
+ from torch_geometric.llm.models.llm import LLM, MAX_NEW_TOKENS
7
7
  from torch_geometric.utils import scatter
8
8
 
9
9
 
@@ -19,8 +19,6 @@ class GRetriever(torch.nn.Module):
19
19
  :obj:`peft` for training the LLM, see
20
20
  `here <https://huggingface.co/docs/peft/en/index>`_ for details.
21
21
  (default: :obj:`False`)
22
- mlp_out_channels (int, optional): The size of each graph embedding
23
- after projection. (default: :obj:`4096`)
24
22
  mlp_out_tokens (int, optional): Number of LLM prefix tokens to
25
23
  reserve for GNN output. (default: :obj:`1`)
26
24
 
@@ -42,15 +40,14 @@ class GRetriever(torch.nn.Module):
42
40
  def __init__(
43
41
  self,
44
42
  llm: LLM,
45
- gnn: torch.nn.Module,
43
+ gnn: torch.nn.Module = None,
46
44
  use_lora: bool = False,
47
- mlp_out_channels: int = 4096,
48
45
  mlp_out_tokens: int = 1,
49
46
  ) -> None:
50
47
  super().__init__()
51
48
 
52
49
  self.llm = llm
53
- self.gnn = gnn.to(self.llm.device)
50
+ self.gnn = gnn.to(self.llm.device) if gnn is not None else None
54
51
 
55
52
  self.word_embedding = self.llm.word_embedding
56
53
  self.llm_generator = self.llm.llm
@@ -76,14 +73,18 @@ class GRetriever(torch.nn.Module):
76
73
  )
77
74
  self.llm_generator = get_peft_model(self.llm_generator, config)
78
75
 
79
- mlp_hidden_channels = self.gnn.out_channels
80
- self.projector = torch.nn.Sequential(
81
- torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
82
- torch.nn.Sigmoid(),
83
- torch.nn.Linear(mlp_hidden_channels,
84
- mlp_out_channels * mlp_out_tokens),
85
- torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)),
86
- ).to(self.llm.device)
76
+ if self.gnn is not None:
77
+ mlp_out_channels = llm.word_embedding.embedding_dim
78
+ mlp_hidden_channels = self.gnn.out_channels
79
+ self.projector = torch.nn.Sequential(
80
+ torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
81
+ torch.nn.Sigmoid(),
82
+ torch.nn.Linear(mlp_hidden_channels,
83
+ mlp_out_channels * mlp_out_tokens),
84
+ torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)),
85
+ ).to(self.llm.device)
86
+
87
+ self.seq_length_stats = []
87
88
 
88
89
  def encode(
89
90
  self,
@@ -98,7 +99,16 @@ class GRetriever(torch.nn.Module):
98
99
  edge_attr = edge_attr.to(self.llm.device)
99
100
  batch = batch.to(self.llm.device)
100
101
 
101
- out = self.gnn(x, edge_index, edge_attr=edge_attr)
102
+ model_specific_kwargs = {}
103
+
104
+ # duck typing for SGFormer to get around circular import
105
+ if (hasattr(self.gnn, 'trans_conv')
106
+ and hasattr(self.gnn, 'graph_conv')):
107
+ model_specific_kwargs['batch'] = batch
108
+ else:
109
+ model_specific_kwargs['edge_attr'] = edge_attr
110
+
111
+ out = self.gnn(x, edge_index, **model_specific_kwargs)
102
112
  return scatter(out, batch, dim=0, reduce='mean')
103
113
 
104
114
  def forward(
@@ -127,27 +137,32 @@ class GRetriever(torch.nn.Module):
127
137
  to give to the LLM, such as textified knowledge graphs.
128
138
  (default: :obj:`None`)
129
139
  """
130
- x = self.encode(x, edge_index, batch, edge_attr)
131
- x = self.projector(x)
132
- xs = x.split(1, dim=0)
133
-
134
- # Handle case where there's more than one embedding for each sample
135
- xs = [x.squeeze(0) for x in xs]
136
-
137
- # Handle questions without node features:
138
- batch_unique = batch.unique()
139
- batch_size = len(question)
140
- if len(batch_unique) < batch_size:
141
- xs = [
142
- xs[i] if i in batch_unique else None for i in range(batch_size)
143
- ]
144
-
140
+ xs = None
141
+ if self.gnn is not None:
142
+ x = self.encode(x, edge_index, batch, edge_attr)
143
+ x = self.projector(x)
144
+ xs = x.split(1, dim=0)
145
+
146
+ # Handle case where theres more than one embedding for each sample
147
+ xs = [x.squeeze(0) for x in xs]
148
+
149
+ # Handle questions without node features:
150
+ batch_unique = batch.unique()
151
+ batch_size = len(question)
152
+ if len(batch_unique) < batch_size:
153
+ xs = [
154
+ xs[i] if i in batch_unique else None
155
+ for i in range(batch_size)
156
+ ]
145
157
  (
146
158
  inputs_embeds,
147
159
  attention_mask,
148
160
  label_input_ids,
149
161
  ) = self.llm._get_embeds(question, additional_text_context, xs, label)
150
162
 
163
+ max_seq_len = inputs_embeds.size(1)
164
+ self.seq_length_stats.append(max_seq_len)
165
+
151
166
  with self.llm.autocast_context:
152
167
  outputs = self.llm_generator(
153
168
  inputs_embeds=inputs_embeds,
@@ -186,35 +201,39 @@ class GRetriever(torch.nn.Module):
186
201
  max_out_tokens (int, optional): How many tokens for the LLM to
187
202
  generate. (default: :obj:`32`)
188
203
  """
189
- x = self.encode(x, edge_index, batch, edge_attr)
190
- x = self.projector(x)
191
- xs = x.split(1, dim=0)
192
-
193
- # Handle case where there's more than one embedding for each sample
194
- xs = [x.squeeze(0) for x in xs]
195
-
196
- # Handle questions without node features:
197
- batch_unique = batch.unique()
198
- batch_size = len(question)
199
- if len(batch_unique) < batch_size:
200
- xs = [
201
- xs[i] if i in batch_unique else None for i in range(batch_size)
202
- ]
204
+ xs = None
205
+ if self.gnn is not None:
206
+ x = self.encode(x, edge_index, batch, edge_attr)
207
+ x = self.projector(x)
208
+ xs = x.split(1, dim=0)
209
+
210
+ # Handle case where theres more than one embedding for each sample
211
+ xs = [x.squeeze(0) for x in xs]
212
+
213
+ # Handle questions without node features:
214
+ batch_unique = batch.unique()
215
+ batch_size = len(question)
216
+ if len(batch_unique) < batch_size:
217
+ xs = [
218
+ xs[i] if i in batch_unique else None
219
+ for i in range(batch_size)
220
+ ]
203
221
 
204
222
  inputs_embeds, attention_mask, _ = self.llm._get_embeds(
205
223
  question, additional_text_context, xs)
206
224
 
207
- bos_token = self.llm.tokenizer(
208
- BOS,
209
- add_special_tokens=False,
210
- ).input_ids[0]
225
+ # bos_token = self.llm.tokenizer(
226
+ # self.llm.tokenizer.bos_token_id,
227
+ # add_special_tokens=False,
228
+ # ).input_ids[0]
211
229
 
212
230
  with self.llm.autocast_context:
213
231
  outputs = self.llm_generator.generate(
214
232
  inputs_embeds=inputs_embeds,
215
233
  max_new_tokens=max_out_tokens,
216
234
  attention_mask=attention_mask,
217
- bos_token_id=bos_token,
235
+ bos_token_id=self.llm.tokenizer.bos_token_id,
236
+ pad_token_id=self.llm.tokenizer.eos_token_id,
218
237
  use_cache=True # Important to set!
219
238
  )
220
239
 
@@ -5,8 +5,8 @@ import torch.nn.functional as F
5
5
  from torch import Tensor
6
6
  from torch.nn import BatchNorm1d, LayerNorm, Linear, ReLU, Sequential
7
7
 
8
+ from torch_geometric.llm.models import SentenceTransformer, VisionTransformer
8
9
  from torch_geometric.nn import GINEConv
9
- from torch_geometric.nn.nlp import SentenceTransformer, VisionTransformer
10
10
  from torch_geometric.utils import add_self_loops, to_dense_batch
11
11
 
12
12