pyg-nightly 2.7.0.dev20241124__py3-none-any.whl → 2.7.0.dev20241127__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -0,0 +1,263 @@
+ import sys
+ from typing import Any, Callable, Dict, List, Optional
+
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+
+ from torch_geometric.data import (
+     Data,
+     InMemoryDataset,
+     download_google_url,
+     extract_zip,
+ )
+ from torch_geometric.io import fs
+
+
+ def safe_index(lst: List[Any], e: int) -> int:
+     return lst.index(e) if e in lst else len(lst) - 1
+
+
+ class GitMolDataset(InMemoryDataset):
+     r"""The dataset from the `"GIT-Mol: A Multi-modal Large Language Model
+     for Molecular Science with Graph, Image, and Text"
+     <https://arxiv.org/pdf/2308.06911>`_ paper.
+
+     Args:
+         root (str): Root directory where the dataset should be saved.
+         transform (callable, optional): A function/transform that takes in an
+             :obj:`torch_geometric.data.Data` object and returns a transformed
+             version. The data object will be transformed before every access.
+             (default: :obj:`None`)
+         pre_transform (callable, optional): A function/transform that takes in
+             an :obj:`torch_geometric.data.Data` object and returns a
+             transformed version. The data object will be transformed before
+             being saved to disk. (default: :obj:`None`)
+         pre_filter (callable, optional): A function that takes in an
+             :obj:`torch_geometric.data.Data` object and returns a boolean
+             value, indicating whether the data object should be included in the
+             final dataset. (default: :obj:`None`)
+         force_reload (bool, optional): Whether to re-process the dataset.
+             (default: :obj:`False`)
+         split (int, optional): The dataset split to load (:obj:`0`: train,
+             :obj:`1`: valid, :obj:`2`: test). (default: :obj:`0`)
+     """
+
+     raw_url_id = '1loBXabD6ncAFY-vanRsVtRUSFkEtBweg'
+
+     def __init__(
+         self,
+         root: str,
+         transform: Optional[Callable] = None,
+         pre_transform: Optional[Callable] = None,
+         pre_filter: Optional[Callable] = None,
+         force_reload: bool = False,
+         split: int = 0,
+     ):
+         from torchvision import transforms
+
+         self.split = split
+
+         if self.split == 0:
+             self.img_transform = transforms.Compose([
+                 transforms.Resize((224, 224)),
+                 transforms.RandomRotation(15),
+                 transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                      std=[0.229, 0.224, 0.225])
+             ])
+         else:
+             self.img_transform = transforms.Compose([
+                 transforms.Resize((224, 224)),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                      std=[0.229, 0.224, 0.225])
+             ])
+
+         super().__init__(root, transform, pre_transform, pre_filter,
+                          force_reload=force_reload)
+
+         self.load(self.processed_paths[0])
+
+     @property
+     def raw_file_names(self) -> List[str]:
+         return ['train_3500.pkl', 'valid_450.pkl', 'test_450.pkl']
+
+     @property
+     def processed_file_names(self) -> str:
+         return ['train.pt', 'valid.pt', 'test.pt'][self.split]
+
+     def download(self) -> None:
+         file_path = download_google_url(
+             self.raw_url_id,
+             self.raw_dir,
+             'gitmol.zip',
+         )
+         extract_zip(file_path, self.raw_dir)
+
+     def process(self) -> None:
+         import pandas as pd
+         from PIL import Image
+
+         try:
+             from rdkit import Chem, RDLogger
+             RDLogger.DisableLog('rdApp.*')  # type: ignore
+             WITH_RDKIT = True
+
+         except ImportError:
+             WITH_RDKIT = False
+
+         if not WITH_RDKIT:
+             print(("Using a pre-processed version of the dataset. Please "
+                    "install 'rdkit' to alternatively process the raw data."),
+                   file=sys.stderr)
+
+             data_list = fs.torch_load(self.raw_paths[0])
+             data_list = [Data(**data_dict) for data_dict in data_list]
+
+             if self.pre_filter is not None:
+                 data_list = [d for d in data_list if self.pre_filter(d)]
+
+             if self.pre_transform is not None:
+                 data_list = [self.pre_transform(d) for d in data_list]
+
+             self.save(data_list, self.processed_paths[0])
+             return
+
+         allowable_features: Dict[str, List[Any]] = {
+             'possible_atomic_num_list':
+             list(range(1, 119)) + ['misc'],
+             'possible_formal_charge_list':
+             [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'],
+             'possible_chirality_list': [
+                 Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
+                 Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
+                 Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
+                 Chem.rdchem.ChiralType.CHI_OTHER
+             ],
+             'possible_hybridization_list': [
+                 Chem.rdchem.HybridizationType.SP,
+                 Chem.rdchem.HybridizationType.SP2,
+                 Chem.rdchem.HybridizationType.SP3,
+                 Chem.rdchem.HybridizationType.SP3D,
+                 Chem.rdchem.HybridizationType.SP3D2,
+                 Chem.rdchem.HybridizationType.UNSPECIFIED, 'misc'
+             ],
+             'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
+             'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6],
+             'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
+             'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'],
+             'possible_is_aromatic_list': [False, True],
+             'possible_is_in_ring_list': [False, True],
+             'possible_bond_type_list': [
+                 Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
+                 Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC,
+                 Chem.rdchem.BondType.ZERO
+             ],
+             'possible_bond_dirs': [  # only for double bond stereo information
+                 Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT,
+                 Chem.rdchem.BondDir.ENDDOWNRIGHT
+             ],
+             'possible_bond_stereo_list': [
+                 Chem.rdchem.BondStereo.STEREONONE,
+                 Chem.rdchem.BondStereo.STEREOZ,
+                 Chem.rdchem.BondStereo.STEREOE,
+                 Chem.rdchem.BondStereo.STEREOCIS,
+                 Chem.rdchem.BondStereo.STEREOTRANS,
+                 Chem.rdchem.BondStereo.STEREOANY,
+             ],
+             'possible_is_conjugated_list': [False, True]
+         }
+
+         data = pd.read_pickle(
+             f'{self.raw_dir}/igcdata_toy/{self.raw_file_names[self.split]}')
+
+         data_list = []
+         for _, r in tqdm(data.iterrows(), total=data.shape[0]):
+             smiles = r['isosmiles']
+             mol = Chem.MolFromSmiles(smiles.strip('\n'))
+             if mol is not None:
+                 # text
+                 summary = r['summary']
+                 # image
+                 cid = r['cid']
+                 img_file = f'{self.raw_dir}/igcdata_toy/imgs/CID_{cid}.png'
+                 img = Image.open(img_file).convert('RGB')
+                 img = self.img_transform(img).unsqueeze(0)
+                 # graph
+                 atom_features_list = []
+                 for atom in mol.GetAtoms():  # type: ignore
+                     atom_feature = [
+                         safe_index(
+                             allowable_features['possible_atomic_num_list'],
+                             atom.GetAtomicNum()),
+                         allowable_features['possible_chirality_list'].index(
+                             atom.GetChiralTag()),
+                         safe_index(allowable_features['possible_degree_list'],
+                                    atom.GetTotalDegree()),
+                         safe_index(
+                             allowable_features['possible_formal_charge_list'],
+                             atom.GetFormalCharge()),
+                         safe_index(allowable_features['possible_numH_list'],
+                                    atom.GetTotalNumHs()),
+                         safe_index(
+                             allowable_features[
+                                 'possible_number_radical_e_list'],
+                             atom.GetNumRadicalElectrons()),
+                         safe_index(
+                             allowable_features['possible_hybridization_list'],
+                             atom.GetHybridization()),
+                         allowable_features['possible_is_aromatic_list'].index(
+                             atom.GetIsAromatic()),
+                         allowable_features['possible_is_in_ring_list'].index(
+                             atom.IsInRing()),
+                     ]
+                     atom_features_list.append(atom_feature)
+                 x = torch.tensor(np.array(atom_features_list),
+                                  dtype=torch.long)
+
+                 edges_list = []
+                 edge_features_list = []
+                 for bond in mol.GetBonds():  # type: ignore
+                     i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
+                     edge_feature = [
+                         safe_index(
+                             allowable_features['possible_bond_type_list'],
+                             bond.GetBondType()),
+                         allowable_features['possible_bond_stereo_list'].index(
+                             bond.GetStereo()),
+                         allowable_features['possible_is_conjugated_list'].
+                         index(bond.GetIsConjugated()),
+                     ]
+                     edges_list.append((i, j))
+                     edge_features_list.append(edge_feature)
+                     edges_list.append((j, i))
+                     edge_features_list.append(edge_feature)
+
+                 edge_index = torch.tensor(
+                     np.array(edges_list).T,
+                     dtype=torch.long,
+                 )
+                 edge_attr = torch.tensor(
+                     np.array(edge_features_list),
+                     dtype=torch.long,
+                 )
+
+                 data = Data(
+                     x=x,
+                     edge_index=edge_index,
+                     smiles=smiles,
+                     edge_attr=edge_attr,
+                     image=img,
+                     caption=summary,
+                 )
+
+                 if self.pre_filter is not None and not self.pre_filter(data):
+                     continue
+                 if self.pre_transform is not None:
+                     data = self.pre_transform(data)
+
+                 data_list.append(data)
+
+         self.save(data_list, self.processed_paths[0])
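For orientation, a minimal usage sketch of the new dataset (not part of the diff; it assumes the class is re-exported as torch_geometric.datasets.GitMolDataset, and the root path is a placeholder):

    from torch_geometric.datasets import GitMolDataset

    # Downloads 'gitmol.zip' on first use and processes the selected split;
    # without 'rdkit' installed, a pre-processed version of the graphs is loaded.
    train_dataset = GitMolDataset(root='data/GITMol', split=0)  # 0: train, 1: valid, 2: test

    data = train_dataset[0]
    # data.x:         [num_atoms, 9] categorical atom features
    # data.edge_attr: [2 * num_bonds, 3] categorical bond features
    # data.image:     [1, 3, 224, 224] normalized molecule image
    # data.smiles / data.caption: SMILES string and text summary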
@@ -22,6 +22,7 @@ from .dynamic_batch_sampler import DynamicBatchSampler
  from .prefetch import PrefetchLoader
  from .cache import CachedLoader
  from .mixin import AffinityMixin
+ from .rag_loader import RAGQueryLoader

  __all__ = classes = [
      'DataLoader',
@@ -50,6 +51,7 @@ __all__ = classes = [
      'PrefetchLoader',
      'CachedLoader',
      'AffinityMixin',
+     'RAGQueryLoader',
  ]

  RandomNodeSampler = deprecated(
@@ -0,0 +1,106 @@
+ from abc import abstractmethod
+ from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
+
+ from torch_geometric.data import Data, FeatureStore, HeteroData
+ from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
+ from torch_geometric.typing import InputEdges, InputNodes
+
+
+ class RAGFeatureStore(Protocol):
+     """Feature store for remote GNN RAG backend."""
+     @abstractmethod
+     def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes:
+         """Compares the query against all nodes to find the closest ones.
+         Returns the indices of the nodes that are to be used as seeds for the
+         RAG Sampler.
+         """
+         ...
+
+     @abstractmethod
+     def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges:
+         """Compares the query against all edges to find the closest ones.
+         Returns the edge indices that are to be used as seeds for the RAG
+         Sampler.
+         """
+         ...
+
+     @abstractmethod
+     def load_subgraph(
+         self, sample: Union[SamplerOutput, HeteroSamplerOutput]
+     ) -> Union[Data, HeteroData]:
+         """Combines sampled subgraph output with features in a Data object."""
+         ...
+
+
+ class RAGGraphStore(Protocol):
+     """Graph store for remote GNN RAG backend."""
+     @abstractmethod
+     def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,
+                         **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:
+         """Sample a subgraph using the seeded nodes and edges."""
+         ...
+
+     @abstractmethod
+     def register_feature_store(self, feature_store: FeatureStore):
+         """Register a feature store to be used with the sampler. Samplers need
+         info from the feature store in order to work properly on HeteroGraphs.
+         """
+         ...
+
+
+ # TODO: Make compatible with Heterographs
+
+
+ class RAGQueryLoader:
+     def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
+                  local_filter: Optional[Callable[[Data, Any], Data]] = None,
+                  seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
+                  seed_edges_kwargs: Optional[Dict[str, Any]] = None,
+                  sampler_kwargs: Optional[Dict[str, Any]] = None,
+                  loader_kwargs: Optional[Dict[str, Any]] = None):
+         """Loader meant for making queries from a remote backend.
+
+         Args:
+             data (Tuple[RAGFeatureStore, RAGGraphStore]): Remote FeatureStore
+                 and GraphStore to load from. Assumed to conform to the
+                 protocols listed above.
+             local_filter (Optional[Callable[[Data, Any], Data]], optional):
+                 Optional local transform to apply to data after retrieval.
+                 Defaults to None.
+             seed_nodes_kwargs (Optional[Dict[str, Any]], optional): Parameters
+                 to pass into process for fetching seed nodes. Defaults to None.
+             seed_edges_kwargs (Optional[Dict[str, Any]], optional): Parameters
+                 to pass into process for fetching seed edges. Defaults to None.
+             sampler_kwargs (Optional[Dict[str, Any]], optional): Parameters to
+                 pass into process for sampling graph. Defaults to None.
+             loader_kwargs (Optional[Dict[str, Any]], optional): Parameters to
+                 pass into process for loading graph features. Defaults to None.
+         """
+         fstore, gstore = data
+         self.feature_store = fstore
+         self.graph_store = gstore
+         self.graph_store.register_feature_store(self.feature_store)
+         self.local_filter = local_filter
+         self.seed_nodes_kwargs = seed_nodes_kwargs or {}
+         self.seed_edges_kwargs = seed_edges_kwargs or {}
+         self.sampler_kwargs = sampler_kwargs or {}
+         self.loader_kwargs = loader_kwargs or {}
+
+     def query(self, query: Any) -> Data:
+         """Retrieve a subgraph associated with the query with all its feature
+         attributes.
+         """
+         seed_nodes = self.feature_store.retrieve_seed_nodes(
+             query, **self.seed_nodes_kwargs)
+         seed_edges = self.feature_store.retrieve_seed_edges(
+             query, **self.seed_edges_kwargs)
+
+         subgraph_sample = self.graph_store.sample_subgraph(
+             seed_nodes, seed_edges, **self.sampler_kwargs)
+
+         data = self.feature_store.load_subgraph(sample=subgraph_sample,
+                                                 **self.loader_kwargs)
+
+         if self.local_filter:
+             data = self.local_filter(data, query)
+         return data
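To show how the pieces above fit together, here is a self-contained toy sketch (not part of the diff; ToyFeatureStore and ToyGraphStore are hypothetical stand-ins for real backends that satisfy the RAGFeatureStore and RAGGraphStore protocols, and the sample returned by sample_subgraph is a plain tuple rather than a SamplerOutput):

    import torch
    from torch_geometric.data import Data
    from torch_geometric.loader import RAGQueryLoader

    class ToyFeatureStore:
        def __init__(self, x: torch.Tensor):
            self.x = x

        def retrieve_seed_nodes(self, query, k: int = 2):
            return torch.arange(k)  # pretend the first k nodes match the query

        def retrieve_seed_edges(self, query, k: int = 2):
            return torch.arange(k)  # pretend the first k edges match the query

        def load_subgraph(self, sample):
            node_id, edge_index = sample  # whatever the graph store returned
            return Data(x=self.x[node_id], edge_index=edge_index, n_id=node_id)

    class ToyGraphStore:
        def __init__(self, edge_index: torch.Tensor):
            self.edge_index = edge_index

        def register_feature_store(self, feature_store):
            self.feature_store = feature_store

        def sample_subgraph(self, seed_nodes, seed_edges, **kwargs):
            # Keep only edges whose endpoints are both seed nodes:
            mask = (torch.isin(self.edge_index[0], seed_nodes)
                    & torch.isin(self.edge_index[1], seed_nodes))
            return seed_nodes, self.edge_index[:, mask]

    x = torch.randn(5, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
    loader = RAGQueryLoader(data=(ToyFeatureStore(x), ToyGraphStore(edge_index)),
                            seed_nodes_kwargs={'k': 2})
    print(loader.query('toy query'))  # Data(x=[2, 16], edge_index=[2, 1], n_id=[2])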
@@ -29,6 +29,7 @@ from .pmlp import PMLP
  from .neural_fingerprint import NeuralFingerprint
  from .visnet import ViSNet
  from .g_retriever import GRetriever
+ from .git_mol import GITMol
  from .molecule_gpt import MoleculeGPT
  from .glem import GLEM
  # Deprecated:
@@ -78,6 +79,7 @@ __all__ = classes = [
      'NeuralFingerprint',
      'ViSNet',
      'GRetriever',
+     'GITMol',
      'MoleculeGPT',
      'GLEM',
  ]
@@ -21,6 +21,8 @@ class GRetriever(torch.nn.Module):
              (default: :obj:`False`)
          mlp_out_channels (int, optional): The size of each graph embedding
              after projection. (default: :obj:`4096`)
+         mlp_out_tokens (int, optional): Number of LLM prefix tokens to
+             reserve for GNN output. (default: :obj:`1`)

      .. warning::
          This module has been tested with the following HuggingFace models
@@ -43,6 +45,7 @@ class GRetriever(torch.nn.Module):
          gnn: torch.nn.Module,
          use_lora: bool = False,
          mlp_out_channels: int = 4096,
+         mlp_out_tokens: int = 1,
      ) -> None:
          super().__init__()

@@ -77,7 +80,9 @@ class GRetriever(torch.nn.Module):
          self.projector = torch.nn.Sequential(
              torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
              torch.nn.Sigmoid(),
-             torch.nn.Linear(mlp_hidden_channels, mlp_out_channels),
+             torch.nn.Linear(mlp_hidden_channels,
+                             mlp_out_channels * mlp_out_tokens),
+             torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)),
          ).to(self.llm.device)

      def encode(
@@ -126,6 +131,9 @@ class GRetriever(torch.nn.Module):
          x = self.projector(x)
          xs = x.split(1, dim=0)

+         # Handle the case where there is more than one embedding per sample:
+         xs = [x.squeeze(0) for x in xs]
+
          # Handle questions without node features:
          batch_unique = batch.unique()
          batch_size = len(question)
@@ -182,6 +190,9 @@ class GRetriever(torch.nn.Module):
          x = self.projector(x)
          xs = x.split(1, dim=0)

+         # Handle the case where there is more than one embedding per sample:
+         xs = [x.squeeze(0) for x in xs]
+
          # Handle questions without node features:
          batch_unique = batch.unique()
          batch_size = len(question)
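To illustrate what the new mlp_out_tokens option changes, here is a small standalone shape check of the rewritten projector head (not part of the diff; the channel sizes and token count are illustrative values):

    import torch

    mlp_hidden_channels, mlp_out_channels, mlp_out_tokens = 1024, 4096, 2

    projector = torch.nn.Sequential(
        torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
        torch.nn.Sigmoid(),
        torch.nn.Linear(mlp_hidden_channels, mlp_out_channels * mlp_out_tokens),
        torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)),
    )

    x = torch.randn(8, mlp_hidden_channels)   # 8 pooled graph embeddings
    out = projector(x)                        # [8, mlp_out_tokens, mlp_out_channels]
    xs = [t.squeeze(0) for t in out.split(1, dim=0)]
    print(out.shape, xs[0].shape)             # [8, 2, 4096] and [2, 4096]

    # Each graph now contributes mlp_out_tokens embeddings to the LLM prefix
    # instead of a single one.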