pyg-nightly 2.7.0.dev20241124__py3-none-any.whl → 2.7.0.dev20241127__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20241124.dist-info → pyg_nightly-2.7.0.dev20241127.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20241124.dist-info → pyg_nightly-2.7.0.dev20241127.dist-info}/RECORD +19 -13
- torch_geometric/__init__.py +1 -1
- torch_geometric/data/__init__.py +5 -0
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/datasets/__init__.py +2 -0
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/loader/__init__.py +2 -0
- torch_geometric/loader/rag_loader.py +106 -0
- torch_geometric/nn/models/__init__.py +2 -0
- torch_geometric/nn/models/g_retriever.py +12 -1
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/nlp/__init__.py +2 -0
- torch_geometric/nn/nlp/sentence_transformer.py +30 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/sampler/base.py +8 -0
- {pyg_nightly-2.7.0.dev20241124.dist-info → pyg_nightly-2.7.0.dev20241127.dist-info}/WHEEL +0 -0
@@ -0,0 +1,263 @@
|
|
1
|
+
import sys
|
2
|
+
from typing import Any, Callable, Dict, List, Optional
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
import torch
|
6
|
+
from tqdm import tqdm
|
7
|
+
|
8
|
+
from torch_geometric.data import (
|
9
|
+
Data,
|
10
|
+
InMemoryDataset,
|
11
|
+
download_google_url,
|
12
|
+
extract_zip,
|
13
|
+
)
|
14
|
+
from torch_geometric.io import fs
|
15
|
+
|
16
|
+
|
17
|
+
def safe_index(lst: List[Any], e: Any) -> int:
    """Return the index of ``e`` in ``lst``, or the last index if absent.

    The vocabularies this helper is used with typically end in a catch-all
    entry (e.g. ``'misc'``), so an unknown value is mapped onto that final
    slot instead of raising.

    Args:
        lst: The list of allowed values to search.
        e: The value to look up. Not restricted to integers; any object
            comparable by equality can be passed.

    Returns:
        The position of the first element equal to ``e``, or
        ``len(lst) - 1`` when no element matches (``-1`` for an empty list).
    """
    try:
        # Single scan: ``index`` both tests membership and finds the slot.
        return lst.index(e)
    except ValueError:
        return len(lst) - 1
|
19
|
+
|
20
|
+
|
21
|
+
class GitMolDataset(InMemoryDataset):
    r"""The dataset from the `"GIT-Mol: A Multi-modal Large Language Model
    for Molecular Science with Graph, Image, and Text"
    <https://arxiv.org/pdf/2308.06911>`_ paper.

    Every stored example holds a molecular graph (:obj:`x`,
    :obj:`edge_index`, :obj:`edge_attr`), the source :obj:`smiles` string, an
    :obj:`image` tensor and a text :obj:`caption`.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
        split (int, optional): Dataset split, train/valid/test=0/1/2.
            (default: :obj:`0`)
    """

    raw_url_id = '1loBXabD6ncAFY-vanRsVtRUSFkEtBweg'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
        split: int = 0,
    ):
        from torchvision import transforms

        self.split = split

        # Both pipelines share resize / tensor conversion / normalization;
        # only split 0 (train) additionally gets the random augmentations.
        steps: List[Any] = [transforms.Resize((224, 224))]
        if self.split == 0:
            steps += [
                transforms.RandomRotation(15),
                transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
            ]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]
        self.img_transform = transforms.Compose(steps)

        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)

        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        """Names of the raw pickle files, ordered train/valid/test."""
        return ['train_3500.pkl', 'valid_450.pkl', 'test_450.pkl']

    @property
    def processed_file_names(self) -> str:
        """Name of the processed file for the selected split."""
        return ['train.pt', 'valid.pt', 'test.pt'][self.split]

    def download(self) -> None:
        """Download the archive into :obj:`raw_dir` and extract it there."""
        file_path = download_google_url(
            self.raw_url_id,
            self.raw_dir,
            'gitmol.zip',
        )
        extract_zip(file_path, self.raw_dir)

    def process(self) -> None:
        """Build the :class:`Data` objects for the selected split and save
        them to :obj:`processed_paths[0]`.

        For every row of the split's pickle whose SMILES string can be parsed
        by RDKit, the molecule is converted into categorical atom and bond
        index features, its image is loaded and passed through
        :obj:`img_transform`, and the summary text is attached as caption.
        Rows whose SMILES cannot be parsed are skipped.
        """
        import pandas as pd
        from PIL import Image

        try:
            from rdkit import Chem, RDLogger
            RDLogger.DisableLog('rdApp.*')  # type: ignore
            WITH_RDKIT = True

        except ImportError:
            WITH_RDKIT = False

        if not WITH_RDKIT:
            print(("Using a pre-processed version of the dataset. Please "
                   "install 'rdkit' to alternatively process the raw data."),
                  file=sys.stderr)

            # NOTE(review): this fallback always reads `raw_paths[0]`,
            # independent of `self.split` -- confirm that a pre-processed
            # file in the expected format actually exists at that path.
            data_list = fs.torch_load(self.raw_paths[0])
            data_list = [Data(**data_dict) for data_dict in data_list]

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            self.save(data_list, self.processed_paths[0])
            return

        allowable_features: Dict[str, List[Any]] = {
            'possible_atomic_num_list':
            list(range(1, 119)) + ['misc'],
            'possible_formal_charge_list':
            [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'],
            'possible_chirality_list': [
                Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
                Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
                Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
                Chem.rdchem.ChiralType.CHI_OTHER
            ],
            'possible_hybridization_list': [
                Chem.rdchem.HybridizationType.SP,
                Chem.rdchem.HybridizationType.SP2,
                Chem.rdchem.HybridizationType.SP3,
                Chem.rdchem.HybridizationType.SP3D,
                Chem.rdchem.HybridizationType.SP3D2,
                Chem.rdchem.HybridizationType.UNSPECIFIED, 'misc'
            ],
            'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
            'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6],
            'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
            'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'],
            'possible_is_aromatic_list': [False, True],
            'possible_is_in_ring_list': [False, True],
            'possible_bond_type_list': [
                Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC,
                Chem.rdchem.BondType.ZERO
            ],
            'possible_bond_dirs': [  # only for double bond stereo information
                Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT,
                Chem.rdchem.BondDir.ENDDOWNRIGHT
            ],
            'possible_bond_stereo_list': [
                Chem.rdchem.BondStereo.STEREONONE,
                Chem.rdchem.BondStereo.STEREOZ,
                Chem.rdchem.BondStereo.STEREOE,
                Chem.rdchem.BondStereo.STEREOCIS,
                Chem.rdchem.BondStereo.STEREOTRANS,
                Chem.rdchem.BondStereo.STEREOANY,
            ],
            'possible_is_conjugated_list': [False, True]
        }

        # Number of entries in each per-atom / per-bond feature vector built
        # below; used to keep tensors 2-D even when a molecule has no bonds.
        num_atom_features = 9
        num_bond_features = 3

        # Named `df` so the per-sample `Data` object created in the loop does
        # not shadow the table that is being iterated.
        df = pd.read_pickle(
            f'{self.raw_dir}/igcdata_toy/{self.raw_file_names[self.split]}')

        data_list = []
        for _, r in tqdm(df.iterrows(), total=df.shape[0]):
            smiles = r['isosmiles']
            mol = Chem.MolFromSmiles(smiles.strip('\n'))
            if mol is None:  # SMILES could not be parsed; skip the row.
                continue

            # text
            summary = r['summary']

            # image
            cid = r['cid']
            img_file = f'{self.raw_dir}/igcdata_toy/imgs/CID_{cid}.png'
            with Image.open(img_file) as raw_img:
                img = raw_img.convert('RGB')
            # The transform runs here, at processing time, so its output
            # (with a leading dimension of size 1 added) is what gets saved.
            img = self.img_transform(img).unsqueeze(0)

            # graph
            atom_features_list = []
            for atom in mol.GetAtoms():  # type: ignore
                atom_feature = [
                    safe_index(
                        allowable_features['possible_atomic_num_list'],
                        atom.GetAtomicNum()),
                    allowable_features['possible_chirality_list'].index(
                        atom.GetChiralTag()),
                    safe_index(allowable_features['possible_degree_list'],
                               atom.GetTotalDegree()),
                    safe_index(
                        allowable_features['possible_formal_charge_list'],
                        atom.GetFormalCharge()),
                    safe_index(allowable_features['possible_numH_list'],
                               atom.GetTotalNumHs()),
                    safe_index(
                        allowable_features[
                            'possible_number_radical_e_list'],
                        atom.GetNumRadicalElectrons()),
                    safe_index(
                        allowable_features['possible_hybridization_list'],
                        atom.GetHybridization()),
                    allowable_features['possible_is_aromatic_list'].index(
                        atom.GetIsAromatic()),
                    allowable_features['possible_is_in_ring_list'].index(
                        atom.GetIsInRing() if hasattr(atom, 'GetIsInRing')
                        else atom.IsInRing()),
                ]
                atom_features_list.append(atom_feature)
            x = torch.tensor(
                np.array(atom_features_list),
                dtype=torch.long,
            ).reshape(-1, num_atom_features)

            edges_list = []
            edge_features_list = []
            for bond in mol.GetBonds():  # type: ignore
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                edge_feature = [
                    safe_index(
                        allowable_features['possible_bond_type_list'],
                        bond.GetBondType()),
                    allowable_features['possible_bond_stereo_list'].index(
                        bond.GetStereo()),
                    allowable_features['possible_is_conjugated_list'].
                    index(bond.GetIsConjugated()),
                ]
                # Store every bond in both directions with the same features.
                edges_list.append((i, j))
                edge_features_list.append(edge_feature)
                edges_list.append((j, i))
                edge_features_list.append(edge_feature)

            # The reshapes are no-ops for molecules with bonds; for a
            # bond-free molecule they turn the otherwise 1-D empty tensors
            # into the expected [2, 0] and [0, num_bond_features] shapes.
            edge_index = torch.tensor(
                np.array(edges_list).T,
                dtype=torch.long,
            ).reshape(2, -1)
            edge_attr = torch.tensor(
                np.array(edge_features_list),
                dtype=torch.long,
            ).reshape(-1, num_bond_features)

            data = Data(
                x=x,
                edge_index=edge_index,
                smiles=smiles,
                edge_attr=edge_attr,
                image=img,
                caption=summary,
            )

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        self.save(data_list, self.processed_paths[0])
|
@@ -22,6 +22,7 @@ from .dynamic_batch_sampler import DynamicBatchSampler
|
|
22
22
|
from .prefetch import PrefetchLoader
|
23
23
|
from .cache import CachedLoader
|
24
24
|
from .mixin import AffinityMixin
|
25
|
+
from .rag_loader import RAGQueryLoader
|
25
26
|
|
26
27
|
__all__ = classes = [
|
27
28
|
'DataLoader',
|
@@ -50,6 +51,7 @@ __all__ = classes = [
|
|
50
51
|
'PrefetchLoader',
|
51
52
|
'CachedLoader',
|
52
53
|
'AffinityMixin',
|
54
|
+
'RAGQueryLoader',
|
53
55
|
]
|
54
56
|
|
55
57
|
RandomNodeSampler = deprecated(
|
@@ -0,0 +1,106 @@
|
|
1
|
+
from abc import abstractmethod
|
2
|
+
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
|
3
|
+
|
4
|
+
from torch_geometric.data import Data, FeatureStore, HeteroData
|
5
|
+
from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
|
6
|
+
from torch_geometric.typing import InputEdges, InputNodes
|
7
|
+
|
8
|
+
|
9
|
+
class RAGFeatureStore(Protocol):
    """Interface a feature store has to provide to serve as the remote
    backend of a GNN RAG pipeline.
    """
    @abstractmethod
    def retrieve_seed_nodes(
        self,
        query: Any,
        **kwargs,
    ) -> InputNodes:
        """Compare :obj:`query` against the nodes and return the indices of
        the closest ones, which become the seed nodes of the RAG sampler.
        """
        ...

    @abstractmethod
    def retrieve_seed_edges(
        self,
        query: Any,
        **kwargs,
    ) -> InputEdges:
        """Compare :obj:`query` against the edges and return the indices of
        the closest ones, which become the seed edges of the RAG sampler.
        """
        ...

    @abstractmethod
    def load_subgraph(
        self,
        sample: Union[SamplerOutput, HeteroSamplerOutput],
    ) -> Union[Data, HeteroData]:
        """Attach features to a sampled subgraph and return the result as a
        data object.
        """
        ...
|
33
|
+
|
34
|
+
|
35
|
+
class RAGGraphStore(Protocol):
    """Interface a graph store has to provide to serve as the remote backend
    of a GNN RAG pipeline.
    """
    @abstractmethod
    def sample_subgraph(
        self,
        seed_nodes: InputNodes,
        seed_edges: InputEdges,
        **kwargs,
    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
        """Return a subgraph sampled around the given seed nodes and seed
        edges.
        """
        ...

    @abstractmethod
    def register_feature_store(self, feature_store: FeatureStore):
        """Make a feature store known to the sampler, which requires
        information from it to operate correctly on heterogeneous graphs.
        """
        ...
|
49
|
+
|
50
|
+
|
51
|
+
# TODO: Make compatible with Heterographs
|
52
|
+
|
53
|
+
|
54
|
+
class RAGQueryLoader:
    """Loader that turns a query into a feature-complete subgraph by
    delegating to a feature store / graph store pair.
    """
    def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
                 local_filter: Optional[Callable[[Data, Any], Data]] = None,
                 seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
                 seed_edges_kwargs: Optional[Dict[str, Any]] = None,
                 sampler_kwargs: Optional[Dict[str, Any]] = None,
                 loader_kwargs: Optional[Dict[str, Any]] = None):
        """Loader meant for making queries from a remote backend.

        The feature store is registered with the graph store as part of
        construction.

        Args:
            data (Tuple[RAGFeatureStore, RAGGraphStore]): Remote FeatureStore
                and GraphStore to load from. Assumed to conform to the
                :class:`RAGFeatureStore` and :class:`RAGGraphStore`
                protocols.
            local_filter (Optional[Callable[[Data, Any], Data]], optional):
                Optional local transform to apply to data after retrieval.
                It is called with the retrieved data and the query.
                Defaults to None.
            seed_nodes_kwargs (Optional[Dict[str, Any]], optional): Parameters
                to pass into process for fetching seed nodes. Defaults to None.
            seed_edges_kwargs (Optional[Dict[str, Any]], optional): Parameters
                to pass into process for fetching seed edges. Defaults to None.
            sampler_kwargs (Optional[Dict[str, Any]], optional): Parameters to
                pass into process for sampling graph. Defaults to None.
            loader_kwargs (Optional[Dict[str, Any]], optional): Parameters to
                pass into process for loading graph features. Defaults to None.
        """
        fstore, gstore = data
        self.feature_store = fstore
        self.graph_store = gstore
        self.graph_store.register_feature_store(self.feature_store)
        self.local_filter = local_filter
        # Missing keyword dictionaries are normalized to empty ones so they
        # can always be unpacked with ``**`` in :meth:`query`.
        self.seed_nodes_kwargs = seed_nodes_kwargs or {}
        self.seed_edges_kwargs = seed_edges_kwargs or {}
        self.sampler_kwargs = sampler_kwargs or {}
        self.loader_kwargs = loader_kwargs or {}

    def query(self, query: Any) -> Data:
        """Retrieve a subgraph associated with the query with all its feature
        attributes.

        Seed nodes and seed edges are fetched from the feature store, a
        subgraph is sampled around them by the graph store, and the feature
        store then loads the features for that sample. If a
        :obj:`local_filter` was given, its result is returned instead.
        """
        seed_nodes = self.feature_store.retrieve_seed_nodes(
            query, **self.seed_nodes_kwargs)
        seed_edges = self.feature_store.retrieve_seed_edges(
            query, **self.seed_edges_kwargs)

        subgraph_sample = self.graph_store.sample_subgraph(
            seed_nodes, seed_edges, **self.sampler_kwargs)

        data = self.feature_store.load_subgraph(sample=subgraph_sample,
                                                **self.loader_kwargs)

        # Compare against None explicitly: a callable object that happens to
        # evaluate as falsy (e.g. one defining ``__len__``) must still run.
        if self.local_filter is not None:
            data = self.local_filter(data, query)
        return data
|
@@ -29,6 +29,7 @@ from .pmlp import PMLP
|
|
29
29
|
from .neural_fingerprint import NeuralFingerprint
|
30
30
|
from .visnet import ViSNet
|
31
31
|
from .g_retriever import GRetriever
|
32
|
+
from .git_mol import GITMol
|
32
33
|
from .molecule_gpt import MoleculeGPT
|
33
34
|
from .glem import GLEM
|
34
35
|
# Deprecated:
|
@@ -78,6 +79,7 @@ __all__ = classes = [
|
|
78
79
|
'NeuralFingerprint',
|
79
80
|
'ViSNet',
|
80
81
|
'GRetriever',
|
82
|
+
'GITMol',
|
81
83
|
'MoleculeGPT',
|
82
84
|
'GLEM',
|
83
85
|
]
|
@@ -21,6 +21,8 @@ class GRetriever(torch.nn.Module):
|
|
21
21
|
(default: :obj:`False`)
|
22
22
|
mlp_out_channels (int, optional): The size of each graph embedding
|
23
23
|
after projection. (default: :obj:`4096`)
|
24
|
+
mlp_out_tokens (int, optional): Number of LLM prefix tokens to
|
25
|
+
reserve for GNN output. (default: :obj:`1`)
|
24
26
|
|
25
27
|
.. warning::
|
26
28
|
This module has been tested with the following HuggingFace models
|
@@ -43,6 +45,7 @@ class GRetriever(torch.nn.Module):
|
|
43
45
|
gnn: torch.nn.Module,
|
44
46
|
use_lora: bool = False,
|
45
47
|
mlp_out_channels: int = 4096,
|
48
|
+
mlp_out_tokens: int = 1,
|
46
49
|
) -> None:
|
47
50
|
super().__init__()
|
48
51
|
|
@@ -77,7 +80,9 @@ class GRetriever(torch.nn.Module):
|
|
77
80
|
self.projector = torch.nn.Sequential(
|
78
81
|
torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
|
79
82
|
torch.nn.Sigmoid(),
|
80
|
-
torch.nn.Linear(mlp_hidden_channels,
|
83
|
+
torch.nn.Linear(mlp_hidden_channels,
|
84
|
+
mlp_out_channels * mlp_out_tokens),
|
85
|
+
torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)),
|
81
86
|
).to(self.llm.device)
|
82
87
|
|
83
88
|
def encode(
|
@@ -126,6 +131,9 @@ class GRetriever(torch.nn.Module):
|
|
126
131
|
x = self.projector(x)
|
127
132
|
xs = x.split(1, dim=0)
|
128
133
|
|
134
|
+
# Handle case where there's more than one embedding for each sample
|
135
|
+
xs = [x.squeeze(0) for x in xs]
|
136
|
+
|
129
137
|
# Handle questions without node features:
|
130
138
|
batch_unique = batch.unique()
|
131
139
|
batch_size = len(question)
|
@@ -182,6 +190,9 @@ class GRetriever(torch.nn.Module):
|
|
182
190
|
x = self.projector(x)
|
183
191
|
xs = x.split(1, dim=0)
|
184
192
|
|
193
|
+
# Handle case where there's more than one embedding for each sample
|
194
|
+
xs = [x.squeeze(0) for x in xs]
|
195
|
+
|
185
196
|
# Handle questions without node features:
|
186
197
|
batch_unique = batch.unique()
|
187
198
|
batch_size = len(question)
|