pyg-nightly 2.7.0.dev20241124__py3-none-any.whl → 2.7.0.dev20241125__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between the two published versions.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: pyg-nightly
- Version: 2.7.0.dev20241124
+ Version: 2.7.0.dev20241125
  Summary: Graph Neural Network Library for PyTorch
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
- torch_geometric/__init__.py,sha256=3RVflrVxTFQoil4Sv-0x8Wr5IftwVD9-YYAwwGwnwzk,1904
+ torch_geometric/__init__.py,sha256=hPMlzqznHr3x2xBZhYBnmC-i7KVOX-tIpw1gy43En6g,1904
  torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -53,7 +53,7 @@ torch_geometric/data/temporal.py,sha256=WOJ6gFrTLikaLhUvotyUF5ql14FkE5Ox3hNkdSp6
  torch_geometric/data/view.py,sha256=XjkVSc-UWZFCT4DlXLShZtO8duhFQkS9gq88zZXANsk,1089
  torch_geometric/data/lightning/__init__.py,sha256=w3En1tJfy3kSqe1MycpOyZpHFO3fxBCgNCUOznPA3YU,178
  torch_geometric/data/lightning/datamodule.py,sha256=Bn9iaIfE4NWDDWWMqCvBeZ4bIW1Silx_Ol5CPJCliaQ,29242
- torch_geometric/datasets/__init__.py,sha256=f9YqoX9WTSVMzjuLfFD_eCsC4iQ5kbFNQiZru3n6qw0,6013
+ torch_geometric/datasets/__init__.py,sha256=HYgogFHWZabd5yLfc1E4eHy9QsY6ILFRPTgfOorNwWQ,6077
  torch_geometric/datasets/actor.py,sha256=oUxgJIX8bi5hJr1etWNYIFyVQNDDXi1nyVpHGGMEAGQ,4304
  torch_geometric/datasets/airfrans.py,sha256=212gYsk7PvF-qcmvM2YXaOBhFrS79evAGg_sPHXih4w,5439
  torch_geometric/datasets/airports.py,sha256=b3gkv3gY2JkUpmGiz36Z-g7EcnSfU8lBG1YsCOWdJ6k,3758
@@ -92,6 +92,7 @@ torch_geometric/datasets/gdelt_lite.py,sha256=zE1WagpgmsQARQhEgdCBtALRKyuQvIZqxT
  torch_geometric/datasets/ged_dataset.py,sha256=dtd-C6pCygNHLXgVfg3ZTWtTVHKT13Q3GlGrze1_rpo,9551
  torch_geometric/datasets/gemsec.py,sha256=oMTSryTgyed9z_4ydg3ql12KM-_35uqL1AoNls5nG8M,2820
  torch_geometric/datasets/geometry.py,sha256=-BxUMirZcUOf01c3avvF0b6wGPn-4S3Zj3Oau1RaJVk,4223
+ torch_geometric/datasets/git_mol_dataset.py,sha256=fdE7hG_gF9bNGHaUITkEnHsZPf9FZy6F66SvvXJ5Tgc,10713
  torch_geometric/datasets/github.py,sha256=Qhqhkvi6eZ8VF_HqP1rL2iYToZavFNsQh7J1WdeM9dA,2687
  torch_geometric/datasets/gnn_benchmark_dataset.py,sha256=4P8n7czF-gf1egLYlAcSSvfB0GXIKpAbH5UjsuFld1M,6976
  torch_geometric/datasets/heterophilous_graph_dataset.py,sha256=yHHtwl4uPrid0vPOxvPV3sIS8HWdswar8FJ0h0OQ9is,4224
@@ -420,7 +421,7 @@ torch_geometric/nn/kge/distmult.py,sha256=dGQ0bVzjreZgFN1lXE23_IIidsiOq7ehPrMb-N
  torch_geometric/nn/kge/loader.py,sha256=5Uc1j3OUMQnBYSHDqL7pLCty1siFLzoPkztigYO2zP8,771
  torch_geometric/nn/kge/rotate.py,sha256=XLuO1AbyTt5cJxr97ZzoyAyIEsHKesgW5TvDmnGJAao,3208
  torch_geometric/nn/kge/transe.py,sha256=jlejq5BLMm-sb1wWcLDp7pZqCdelWBgjDIC8ctbjSdU,3088
- torch_geometric/nn/models/__init__.py,sha256=dr2-YsRzUdVBM6Ut78FB9Wbjn-kzV0gPwOlWGPdQLY4,2108
+ torch_geometric/nn/models/__init__.py,sha256=vWMKzGBVxA1Fm0uGDLnH4jzYgfhK34CQTRJ-xi5pf5k,2150
  torch_geometric/nn/models/attentive_fp.py,sha256=tkgvw28wg9-JqHIfBllfCwTHrZIUiv85yZJcDqjz3z0,6634
  torch_geometric/nn/models/autoencoder.py,sha256=nGje-zty78Y3hxOJ9o0_6QziJjOvBlknk6z0_fDQwQU,10770
  torch_geometric/nn/models/basic_gnn.py,sha256=PGa0RUMyvrNy_5yRI2jX_zwPsmZXwOQWfsWvxOiHsSk,31225
@@ -431,6 +432,7 @@ torch_geometric/nn/models/deepgcn.py,sha256=tIgT03cj8MghYlxEozpoGvGG_CwpJrGDxv1Z
  torch_geometric/nn/models/dimenet.py,sha256=Kc5p-rB5q-0e8lY22l-OdQTscTxJh2lTEpeRFMdL4RY,36186
  torch_geometric/nn/models/dimenet_utils.py,sha256=Eyn_EiJqwKvuYj6BtRpSxrzMG3v4Gk98X9MxZ7uvwm4,5069
  torch_geometric/nn/models/g_retriever.py,sha256=VueRImNJlh1WvRWcsSXliSw8RlxlzWlu2WSFs_VQaJc,7749
+ torch_geometric/nn/models/git_mol.py,sha256=Wc6Hx6RDDR7sDWRWHfA5eK9e9gFsrTZ9OLmpMfoj3pE,12676
  torch_geometric/nn/models/glem.py,sha256=gqQF4jlU7U_u5-zGeJZuHiEqhSXa-wLU5TghN4u5fYY,16389
  torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
  torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
@@ -454,9 +456,10 @@ torch_geometric/nn/models/schnet.py,sha256=0aaHrVtxApdvn3RHCGLQJW1MbIb--CSYUrx9O
  torch_geometric/nn/models/signed_gcn.py,sha256=J40CnedFIqtKI1LhW1ITSEFRbA_XiJZL6lASrKwUEAI,9841
  torch_geometric/nn/models/tgn.py,sha256=kEGdfLJybkbMT4UMoAh2nCzfX3_nDjfm1cicuPHEwAM,11878
  torch_geometric/nn/models/visnet.py,sha256=97OFMCsPDEI5BCSi7RhoRcU2CNRp7zck2tEzrltFZj4,43192
- torch_geometric/nn/nlp/__init__.py,sha256=JJESTA7w_K8v60XbCd25IqmrKKHLz5OiNexMHYGV2mE,138
+ torch_geometric/nn/nlp/__init__.py,sha256=q6CPUiJHcc9bXw90lyj-ID4F3kfW8uPM-SOxW9uCMHs,213
  torch_geometric/nn/nlp/llm.py,sha256=M15Qn0yHyA6HL2rHCH2p4H6hKjUvLfnzlxdfEFvRxSA,11732
- torch_geometric/nn/nlp/sentence_transformer.py,sha256=VzMtNUYk6FvOVc3PdVets9_2Sb2FdQbzu9H3m6teRlI,3417
+ torch_geometric/nn/nlp/sentence_transformer.py,sha256=q5M7SGtrUzoSiNhKCGFb7JatWiukdhNF6zdq2yiqxwE,4475
+ torch_geometric/nn/nlp/vision_transformer.py,sha256=diVBefjIynzYs8WBlcpTeSVnw1PUecHY--B9Yd-W2hA,863
  torch_geometric/nn/norm/__init__.py,sha256=u2qIDrkbeuObGVXSAIftAlvSd6ouGTtxznCfD-59UiA,669
  torch_geometric/nn/norm/batch_norm.py,sha256=sJKrinHGwA-noIgteg1RD2W06rd0zskD-rXuY-36glY,8283
  torch_geometric/nn/norm/diff_group_norm.py,sha256=b57XvNekrUYGDjNJlGeqvaMGNJmHwopSF0_yyBWlLuA,4722
@@ -623,6 +626,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
  torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
  torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
- pyg_nightly-2.7.0.dev20241124.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
- pyg_nightly-2.7.0.dev20241124.dist-info/METADATA,sha256=4Y_tgdPduB0ylxdgw9u2c98XGvWfN_0is1-4mppzP4Q,62979
- pyg_nightly-2.7.0.dev20241124.dist-info/RECORD,,
+ pyg_nightly-2.7.0.dev20241125.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
+ pyg_nightly-2.7.0.dev20241125.dist-info/METADATA,sha256=bDgjxvVn0QZLKMZH40NUhX3W96-XohGqDUXoYJ8Ly3A,62979
+ pyg_nightly-2.7.0.dev20241125.dist-info/RECORD,,
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')

- __version__ = '2.7.0.dev20241124'
+ __version__ = '2.7.0.dev20241125'

  __all__ = [
      'Index',
@@ -77,6 +77,7 @@ from .myket import MyketDataset
  from .brca_tgca import BrcaTcga
  from .neurograph import NeuroGraphDataset
  from .web_qsp_dataset import WebQSPDataset
+ from .git_mol_dataset import GitMolDataset
  from .molecule_gpt_dataset import MoleculeGPTDataset
  from .tag_dataset import TAGDataset

@@ -192,6 +193,7 @@ homo_datasets = [
      'BrcaTcga',
      'NeuroGraphDataset',
      'WebQSPDataset',
+     'GitMolDataset',
      'MoleculeGPTDataset',
      'TAGDataset',
  ]
@@ -0,0 +1,263 @@
+ import sys
+ from typing import Any, Callable, Dict, List, Optional
+
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+
+ from torch_geometric.data import (
+     Data,
+     InMemoryDataset,
+     download_google_url,
+     extract_zip,
+ )
+ from torch_geometric.io import fs
+
+
+ def safe_index(lst: List[Any], e: int) -> int:
+     return lst.index(e) if e in lst else len(lst) - 1
+
+
+ class GitMolDataset(InMemoryDataset):
+     r"""The dataset from the `"GIT-Mol: A Multi-modal Large Language Model
+     for Molecular Science with Graph, Image, and Text"
+     <https://arxiv.org/pdf/2308.06911>`_ paper.
+
+     Args:
+         root (str): Root directory where the dataset should be saved.
+         transform (callable, optional): A function/transform that takes in an
+             :obj:`torch_geometric.data.Data` object and returns a transformed
+             version. The data object will be transformed before every access.
+             (default: :obj:`None`)
+         pre_transform (callable, optional): A function/transform that takes in
+             an :obj:`torch_geometric.data.Data` object and returns a
+             transformed version. The data object will be transformed before
+             being saved to disk. (default: :obj:`None`)
+         pre_filter (callable, optional): A function that takes in an
+             :obj:`torch_geometric.data.Data` object and returns a boolean
+             value, indicating whether the data object should be included in
+             the final dataset. (default: :obj:`None`)
+         force_reload (bool, optional): Whether to re-process the dataset.
+             (default: :obj:`False`)
+         split (int, optional): Datasets split, train/valid/test=0/1/2.
+             (default: :obj:`0`)
+     """
+
+     raw_url_id = '1loBXabD6ncAFY-vanRsVtRUSFkEtBweg'
+
+     def __init__(
+         self,
+         root: str,
+         transform: Optional[Callable] = None,
+         pre_transform: Optional[Callable] = None,
+         pre_filter: Optional[Callable] = None,
+         force_reload: bool = False,
+         split: int = 0,
+     ):
+         from torchvision import transforms
+
+         self.split = split
+
+         if self.split == 0:
+             self.img_transform = transforms.Compose([
+                 transforms.Resize((224, 224)),
+                 transforms.RandomRotation(15),
+                 transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                      std=[0.229, 0.224, 0.225])
+             ])
+         else:
+             self.img_transform = transforms.Compose([
+                 transforms.Resize((224, 224)),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                      std=[0.229, 0.224, 0.225])
+             ])
+
+         super().__init__(root, transform, pre_transform, pre_filter,
+                          force_reload=force_reload)
+
+         self.load(self.processed_paths[0])
+
+     @property
+     def raw_file_names(self) -> List[str]:
+         return ['train_3500.pkl', 'valid_450.pkl', 'test_450.pkl']
+
+     @property
+     def processed_file_names(self) -> str:
+         return ['train.pt', 'valid.pt', 'test.pt'][self.split]
+
+     def download(self) -> None:
+         file_path = download_google_url(
+             self.raw_url_id,
+             self.raw_dir,
+             'gitmol.zip',
+         )
+         extract_zip(file_path, self.raw_dir)
+
+     def process(self) -> None:
+         import pandas as pd
+         from PIL import Image
+
+         try:
+             from rdkit import Chem, RDLogger
+             RDLogger.DisableLog('rdApp.*')  # type: ignore
+             WITH_RDKIT = True
+
+         except ImportError:
+             WITH_RDKIT = False
+
+         if not WITH_RDKIT:
+             print(("Using a pre-processed version of the dataset. Please "
+                    "install 'rdkit' to alternatively process the raw data."),
+                   file=sys.stderr)
+
+             data_list = fs.torch_load(self.raw_paths[0])
+             data_list = [Data(**data_dict) for data_dict in data_list]
+
+             if self.pre_filter is not None:
+                 data_list = [d for d in data_list if self.pre_filter(d)]
+
+             if self.pre_transform is not None:
+                 data_list = [self.pre_transform(d) for d in data_list]
+
+             self.save(data_list, self.processed_paths[0])
+             return
+
+         allowable_features: Dict[str, List[Any]] = {
+             'possible_atomic_num_list':
+             list(range(1, 119)) + ['misc'],
+             'possible_formal_charge_list':
+             [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'],
+             'possible_chirality_list': [
+                 Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
+                 Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
+                 Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
+                 Chem.rdchem.ChiralType.CHI_OTHER
+             ],
+             'possible_hybridization_list': [
+                 Chem.rdchem.HybridizationType.SP,
+                 Chem.rdchem.HybridizationType.SP2,
+                 Chem.rdchem.HybridizationType.SP3,
+                 Chem.rdchem.HybridizationType.SP3D,
+                 Chem.rdchem.HybridizationType.SP3D2,
+                 Chem.rdchem.HybridizationType.UNSPECIFIED, 'misc'
+             ],
+             'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
+             'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6],
+             'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
+             'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'],
+             'possible_is_aromatic_list': [False, True],
+             'possible_is_in_ring_list': [False, True],
+             'possible_bond_type_list': [
+                 Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
+                 Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC,
+                 Chem.rdchem.BondType.ZERO
+             ],
+             'possible_bond_dirs': [  # only for double bond stereo information
+                 Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT,
+                 Chem.rdchem.BondDir.ENDDOWNRIGHT
+             ],
+             'possible_bond_stereo_list': [
+                 Chem.rdchem.BondStereo.STEREONONE,
+                 Chem.rdchem.BondStereo.STEREOZ,
+                 Chem.rdchem.BondStereo.STEREOE,
+                 Chem.rdchem.BondStereo.STEREOCIS,
+                 Chem.rdchem.BondStereo.STEREOTRANS,
+                 Chem.rdchem.BondStereo.STEREOANY,
+             ],
+             'possible_is_conjugated_list': [False, True]
+         }
+
+         data = pd.read_pickle(
+             f'{self.raw_dir}/igcdata_toy/{self.raw_file_names[self.split]}')
+
+         data_list = []
+         for _, r in tqdm(data.iterrows(), total=data.shape[0]):
+             smiles = r['isosmiles']
+             mol = Chem.MolFromSmiles(smiles.strip('\n'))
+             if mol is not None:
+                 # text
+                 summary = r['summary']
+                 # image
+                 cid = r['cid']
+                 img_file = f'{self.raw_dir}/igcdata_toy/imgs/CID_{cid}.png'
+                 img = Image.open(img_file).convert('RGB')
+                 img = self.img_transform(img).unsqueeze(0)
+                 # graph
+                 atom_features_list = []
+                 for atom in mol.GetAtoms():  # type: ignore
+                     atom_feature = [
+                         safe_index(
+                             allowable_features['possible_atomic_num_list'],
+                             atom.GetAtomicNum()),
+                         allowable_features['possible_chirality_list'].index(
+                             atom.GetChiralTag()),
+                         safe_index(allowable_features['possible_degree_list'],
+                                    atom.GetTotalDegree()),
+                         safe_index(
+                             allowable_features['possible_formal_charge_list'],
+                             atom.GetFormalCharge()),
+                         safe_index(allowable_features['possible_numH_list'],
+                                    atom.GetTotalNumHs()),
+                         safe_index(
+                             allowable_features[
+                                 'possible_number_radical_e_list'],
+                             atom.GetNumRadicalElectrons()),
+                         safe_index(
+                             allowable_features['possible_hybridization_list'],
+                             atom.GetHybridization()),
+                         allowable_features['possible_is_aromatic_list'].index(
+                             atom.GetIsAromatic()),
+                         allowable_features['possible_is_in_ring_list'].index(
+                             atom.IsInRing()),
+                     ]
+                     atom_features_list.append(atom_feature)
+                 x = torch.tensor(np.array(atom_features_list),
+                                  dtype=torch.long)
+
+                 edges_list = []
+                 edge_features_list = []
+                 for bond in mol.GetBonds():  # type: ignore
+                     i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
+                     edge_feature = [
+                         safe_index(
+                             allowable_features['possible_bond_type_list'],
+                             bond.GetBondType()),
+                         allowable_features['possible_bond_stereo_list'].index(
+                             bond.GetStereo()),
+                         allowable_features['possible_is_conjugated_list'].
+                         index(bond.GetIsConjugated()),
+                     ]
+                     edges_list.append((i, j))
+                     edge_features_list.append(edge_feature)
+                     edges_list.append((j, i))
+                     edge_features_list.append(edge_feature)
+
+                 edge_index = torch.tensor(
+                     np.array(edges_list).T,
+                     dtype=torch.long,
+                 )
+                 edge_attr = torch.tensor(
+                     np.array(edge_features_list),
+                     dtype=torch.long,
+                 )
+
+                 data = Data(
+                     x=x,
+                     edge_index=edge_index,
+                     smiles=smiles,
+                     edge_attr=edge_attr,
+                     image=img,
+                     caption=summary,
+                 )
+
+                 if self.pre_filter is not None and not self.pre_filter(data):
+                     continue
+                 if self.pre_transform is not None:
+                     data = self.pre_transform(data)
+
+                 data_list.append(data)
+
+         self.save(data_list, self.processed_paths[0])
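
The new torch_geometric/datasets/git_mol_dataset.py wraps the pre-processed GIT-Mol data (a molecular graph built from SMILES, plus one image and one caption per molecule) as an InMemoryDataset. A minimal usage sketch (not part of the diff), assuming the Google Drive download succeeds and torchvision is installed; the root path below is arbitrary:

    from torch_geometric.datasets import GitMolDataset

    # split selects train/valid/test = 0/1/2, as documented in the class above.
    train_dataset = GitMolDataset(root='data/gitmol', split=0)
    data = train_dataset[0]
    # Each sample carries x, edge_index, edge_attr, smiles, image, and caption.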
@@ -29,6 +29,7 @@ from .pmlp import PMLP
  from .neural_fingerprint import NeuralFingerprint
  from .visnet import ViSNet
  from .g_retriever import GRetriever
+ from .git_mol import GITMol
  from .molecule_gpt import MoleculeGPT
  from .glem import GLEM
  # Deprecated:
@@ -78,6 +79,7 @@ __all__ = classes = [
      'NeuralFingerprint',
      'ViSNet',
      'GRetriever',
+     'GITMol',
      'MoleculeGPT',
      'GLEM',
  ]
@@ -0,0 +1,336 @@
+ from typing import List, Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor
+ from torch.nn import BatchNorm1d, LayerNorm, Linear, ReLU, Sequential
+
+ from torch_geometric.nn import GINEConv
+ from torch_geometric.nn.nlp import SentenceTransformer, VisionTransformer
+ from torch_geometric.utils import add_self_loops, to_dense_batch
+
+
+ class GraphEncoder(torch.nn.Module):
+     def __init__(
+         self,
+         num_layers: int,
+         in_channels: int,
+         dropout: float = 0.,
+         num_atom_type: int = 120,
+         num_chirality_tag: int = 3,
+         num_bond_type: int = 6,
+         num_bond_direction: int = 3,
+     ) -> None:
+         super().__init__()
+
+         self.num_layers = num_layers
+         self.dropout = dropout
+
+         self.x_embed1 = torch.nn.Embedding(num_atom_type, in_channels)
+         self.x_embed2 = torch.nn.Embedding(num_chirality_tag, in_channels)
+         self.edge_embed1 = torch.nn.Embedding(num_bond_type, in_channels)
+         self.edge_embed2 = torch.nn.Embedding(num_bond_direction, in_channels)
+
+         self.gnns = torch.nn.ModuleList()
+         self.batch_norms = torch.nn.ModuleList()
+         for _ in range(num_layers):
+             self.gnns.append(
+                 GINEConv(
+                     nn=Sequential(
+                         Linear(in_channels, in_channels * 2),
+                         ReLU(),
+                         Linear(in_channels * 2, in_channels),
+                     ),
+                     train_eps=True,
+                     edge_dim=in_channels,
+                 ))
+             self.batch_norms.append(BatchNorm1d(in_channels))
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         torch.nn.init.xavier_uniform_(self.x_embed1.weight.data)
+         torch.nn.init.xavier_uniform_(self.x_embed2.weight.data)
+         torch.nn.init.xavier_uniform_(self.edge_embed1.weight.data)
+         torch.nn.init.xavier_uniform_(self.edge_embed2.weight.data)
+
+     def forward(
+         self,
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+         edge_attr: Tensor,
+     ) -> Tensor:
+         x = self.x_embed1(x[:, 0].long()) + self.x_embed2(x[:, 1].long())
+         edge_index, edge_attr = add_self_loops(
+             edge_index,
+             edge_attr,
+             fill_value=0,
+             num_nodes=x.size(0),
+         )
+         edge_attr = self.edge_embed1(edge_attr[:, 0]) + self.edge_embed2(
+             edge_attr[:, 1])
+         for i, (gnn, bn) in enumerate(zip(self.gnns, self.batch_norms)):
+             x = gnn(x, edge_index, edge_attr)
+             x = bn(x)
+             if i < self.num_layers - 1:
+                 x = F.relu(x)
+             x = F.dropout(x, self.dropout, training=self.training)
+
+         x, mask = to_dense_batch(x, batch)
+         return x, mask
+
+
+ class GITFormer(torch.nn.Module):
+     def __init__(
+         self,
+         num_query_token: int,
+         vision_graph_width: int,
+         cross_attention_freq: int = 2,
+     ):
+         super().__init__()
+         from transformers import AutoConfig, AutoModel
+
+         config = AutoConfig.from_pretrained("allenai/scibert_scivocab_uncased")
+         config.encoder_width = vision_graph_width
+         # insert cross-attention layer every other block
+         config.add_cross_attention = True
+         config.is_decoder = True
+         config.cross_attention_freq = cross_attention_freq
+         config.query_length = num_query_token
+         self.Qformer = AutoModel.from_pretrained(
+             "allenai/scibert_scivocab_uncased", config=config)
+         self.query_tokens = torch.nn.Parameter(
+             torch.zeros(1, num_query_token, config.hidden_size))
+         self.query_tokens.data.normal_(mean=0.0, std=config.initializer_range)
+
+
+ class GITMol(torch.nn.Module):
+     r"""The GITMol model from the `"GIT-Mol: A Multi-modal Large Language
+     Model for Molecular Science with Graph, Image, and Text"
+     <https://arxiv.org/pdf/2308.06911>`_ paper.
+
+     .. note::
+         For an example of using :class:`GITMol`, see
+         `examples/llm/git_mol.py <https://github.com/pyg-team/
+         pytorch_geometric/blob/master/examples/llm/git_mol.py>`_.
+     """
+     def __init__(self) -> None:
+         super().__init__()
+         # graph
+         self.graph_encoder = GraphEncoder(num_layers=2, in_channels=16)
+         self.graph_proj = Linear(16, 768)
+         self.ln_graph = LayerNorm(768)
+         # text
+         self.text_encoder = SentenceTransformer(
+             model_name='allenai/scibert_scivocab_uncased',
+             pooling_strategy='last_hidden_state',
+         )
+         self.text_proj = Linear(768, 768)
+         self.ln_text = LayerNorm(768)
+         # vision
+         self.vision_encoder = VisionTransformer(
+             model_name='microsoft/swin-base-patch4-window7-224', )
+         self.vision_proj = Linear(1024, 768)
+         self.ln_vision = LayerNorm(768)
+         # cross-attention
+         self.gitformer = GITFormer(384, 768)
+
+         self.xtm_head = torch.nn.ModuleDict({
+             'image':
+             Linear(self.gitformer.Qformer.config.hidden_size, 2),
+             'graph':
+             Linear(self.gitformer.Qformer.config.hidden_size, 2),
+             'cs_text':
+             Linear(self.gitformer.Qformer.config.hidden_size, 2),
+         })
+
+         self.xtc_proj = torch.nn.ModuleDict({
+             'image':
+             Linear(self.gitformer.Qformer.config.hidden_size, 768),
+             'graph':
+             Linear(self.gitformer.Qformer.config.hidden_size, 768),
+             'cs_text':
+             Linear(self.gitformer.Qformer.config.hidden_size, 768),
+         })
+         self.temp = torch.nn.Parameter(0.07 * torch.ones([]))
+         self.model_freeze()
+
+     def model_freeze(self) -> None:
+         for param in self.graph_encoder.parameters():
+             param.requires_grad = False
+
+         for param in self.vision_encoder.parameters():
+             param.requires_grad = False
+
+     def forward(
+         self,
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+         edge_attr: Optional[Tensor],
+         smiles: List[str],
+         images: Tensor,
+         captions: List[str],
+     ) -> Tensor:
+         batch_size = len(smiles)
+
+         x_vision = self.vision_encoder(images)
+         x_vision = self.vision_proj(x_vision)
+         x_vision = self.ln_vision(x_vision)  # [bs, patch_len, d]
+         vision_atts = torch.ones(x_vision.size()[:-1],
+                                  dtype=torch.long).to(x_vision.device)
+         vision_targets = torch.arange(batch_size).to(x_vision.device)
+
+         x_graph, graph_atts = self.graph_encoder(x, edge_index, batch,
+                                                  edge_attr)
+         x_graph = self.graph_proj(x_graph)
+         x_graph = self.ln_graph(x_graph)  # [bs, node_len, d]
+         graph_targets = torch.arange(batch_size).to(x_graph.device)
+
+         x_smiles = self.text_encoder.encode(smiles)  # [bs, seq_len, d]
+         smiles_atts = torch.ones(x_smiles.size()[:-1],
+                                  dtype=torch.long).to(x_smiles.device)
+         smiles_targets = torch.arange(batch_size).to(x_smiles.device)
+
+         caption_input_ids, caption_attention_masks = self.text_encoder.get_input_ids(  # noqa: E501
+             captions)
+
+         text_output = self.gitformer.Qformer(
+             caption_input_ids,
+             attention_mask=caption_attention_masks,
+             return_dict=True,
+         )
+         text_feat = F.normalize(
+             self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)
+
+         loss = 0
+         for x_embed, x_atts, x_targets, modal in zip(
+             [x_graph, x_smiles, x_vision],
+             [graph_atts, smiles_atts, vision_atts],
+             [graph_targets, smiles_targets, vision_targets],
+             ['graph', 'cs_text', 'image'],
+         ):
+             loss += self._calc_xtc_loss(x_embed, x_atts, x_targets, text_feat,
+                                         modal)
+             loss += self._calc_xtm_loss(x_embed, caption_input_ids,
+                                         caption_attention_masks, modal)
+
+         return loss / 6
+
+     def _calc_xtm_loss(
+         self,
+         x_embeds: Tensor,
+         input_ids: Tensor,
+         attention_mask: Tensor,
+         modal: str,
+     ) -> Tensor:
+         # Initializing lists to hold the original and negative samples
+         x_embeds_list = []
+         text_input_ids_list = []
+         text_attention_mask_list = []
+
+         batch_size = x_embeds.size(0)
+         for i in range(batch_size):
+             # Original samples
+             x_embeds_list.append(x_embeds[i])
+             text_input_ids_list.append(input_ids[i, :])
+             text_attention_mask_list.append(attention_mask[i, :])
+
+             if batch_size > 1:
+                 # Negative samples (neg_text_input_ids corresponds to x_embeds)
+                 neg_text_input_ids = input_ids[i - 1 if i == batch_size -
+                                                1 else i + 1, :]
+                 neg_text_attention_mask = attention_mask[i -
+                                                          1 if i == batch_size -
+                                                          1 else i + 1, :]
+                 text_input_ids_list.append(neg_text_input_ids)
+                 text_attention_mask_list.append(neg_text_attention_mask)
+                 x_embeds_list.append(x_embeds[i, :])
+
+                 # Negative samples (text_input_ids corresponds to neg_x_embeds)
+                 neg_x_embeds = x_embeds[i - 1 if i == batch_size - 1 else i +
+                                         1, :]
+                 x_embeds_list.append(neg_x_embeds)
+                 text_input_ids_list.append(input_ids[i, :])
+                 text_attention_mask_list.append(attention_mask[i, :])
+
+         # Stack all samples into two large tensors
+         x_embeds_all = torch.stack(x_embeds_list, dim=1) \
+             .reshape(-1, x_embeds.size(1), x_embeds.size(2))
+         text_input_ids_all = torch.stack(text_input_ids_list, dim=1) \
+             .reshape(-1, input_ids.size(1))
+         # Create image attention masks for the concatenated tensor
+         image_attns_all = torch.ones(x_embeds_all.size()[:-1],
+                                      dtype=torch.long).to(x_embeds_all.device)
+         query_tokens_xtm = self.gitformer.query_tokens.expand(
+             text_input_ids_all.shape[0], -1, -1)
+         query_attns_xtm = torch.ones(query_tokens_xtm.size()[:-1],
+                                      dtype=torch.long).to(x_embeds_all.device)
+
+         output_xtm = self.gitformer.Qformer(
+             inputs_embeds=query_tokens_xtm,
+             attention_mask=query_attns_xtm,
+             encoder_hidden_states=x_embeds_all,
+             encoder_attention_mask=image_attns_all,
+             return_dict=True,
+         ).last_hidden_state
+
+         xtm_embeddings = output_xtm[:, :query_tokens_xtm.size(1), :]
+
+         xtm_logit = self.xtm_head[modal](xtm_embeddings).mean(dim=1)
+         # Create labels: 1 for the original samples, 0 for the negative samples
+         if batch_size > 1:
+             labels = torch.cat(
+                 [torch.ones(batch_size),
+                  torch.zeros(batch_size * 2)], dim=0)
+         else:
+             labels = torch.ones(batch_size)
+         labels = labels.long().to(xtm_logit.device)
+
+         # Calculate cross entropy loss
+         return F.cross_entropy(xtm_logit, labels)
+
+     def _calc_xtc_loss(
+         self,
+         x_embeds: Tensor,
+         x_atts: Tensor,
+         x_targets: Tensor,
+         text_feat: Tensor,
+         modal: str,
+     ) -> Tensor:
+         query_tokens = self.gitformer.query_tokens.expand(
+             x_embeds.shape[0], -1, -1)
+
+         query_output = self.gitformer.Qformer(
+             inputs_embeds=query_tokens,
+             encoder_hidden_states=x_embeds,
+             encoder_attention_mask=x_atts,
+             return_dict=True,
+         ).last_hidden_state
+
+         x_feats = F.normalize(self.xtc_proj[modal](query_output), dim=-1)
+
+         sim_q2t = torch.matmul(
+             x_feats.unsqueeze(1),
+             text_feat.unsqueeze(-1),
+         ).squeeze(-1)
+
+         # modal-text similarity: aggregate across all query tokens
+         sim_x2t, _ = sim_q2t.max(-1)
+         sim_x2t = sim_x2t / self.temp
+
+         # text-query similarity
+         sim_t2q = torch.matmul(
+             text_feat.unsqueeze(1).unsqueeze(1),
+             x_feats.permute(0, 2, 1),
+         ).squeeze(-2)
+
+         # text-modal similarity: aggregate across all query tokens
+         sim_t2x, _ = sim_t2q.max(-1)
+         sim_t2x = sim_t2x / self.temp
+
+         loss_itc = (
+             F.cross_entropy(sim_x2t, x_targets, label_smoothing=0.1) +
+             F.cross_entropy(sim_t2x, x_targets, label_smoothing=0.1)) / 2
+
+         return loss_itc
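
GITMol as added here is a pre-training module: forward() encodes the graph, SMILES, and image modalities, runs each of them plus the caption through the SciBERT-based Q-Former, and returns the average of the per-modality contrastive (xtc) and matching (xtm) losses. A rough training-step sketch (not part of the diff), assuming a DataLoader over the GitMolDataset from the previous sketch and an optimizer of your choice; the batch size is arbitrary:

    from torch_geometric.loader import DataLoader
    from torch_geometric.nn.models import GITMol

    model = GITMol()
    loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    for batch in loader:
        # The forward pass returns the combined self-supervised loss.
        loss = model(batch.x, batch.edge_index, batch.batch, batch.edge_attr,
                     batch.smiles, batch.image, batch.caption)
        loss.backward()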
@@ -1,7 +1,9 @@
  from .sentence_transformer import SentenceTransformer
+ from .vision_transformer import VisionTransformer
  from .llm import LLM

  __all__ = classes = [
      'SentenceTransformer',
+     'VisionTransformer',
      'LLM',
  ]
@@ -48,6 +48,36 @@ class SentenceTransformer(torch.nn.Module):
          emb = F.normalize(emb, p=2, dim=1)
          return emb

+     def get_input_ids(
+         self,
+         text: List[str],
+         batch_size: Optional[int] = None,
+         output_device: Optional[Union[torch.device, str]] = None,
+     ) -> Tensor:
+         is_empty = len(text) == 0
+         text = ['dummy'] if is_empty else text
+
+         batch_size = len(text) if batch_size is None else batch_size
+
+         input_ids: List[Tensor] = []
+         attention_masks: List[Tensor] = []
+         for start in range(0, len(text), batch_size):
+             token = self.tokenizer(
+                 text[start:start + batch_size],
+                 padding=True,
+                 truncation=True,
+                 return_tensors='pt',
+             )
+             input_ids.append(token.input_ids.to(self.device))
+             attention_masks.append(token.attention_mask.to(self.device))
+
+         def _out(x: List[Tensor]) -> Tensor:
+             out = torch.cat(x, dim=0) if len(x) > 1 else x[0]
+             out = out[:0] if is_empty else out
+             return out.to(output_device)
+
+         return _out(input_ids), _out(attention_masks)
+
      @property
      def device(self) -> torch.device:
          return next(iter(self.model.parameters())).device
@@ -0,0 +1,33 @@
+ from typing import Optional, Union
+
+ import torch
+ from torch import Tensor
+
+
+ class VisionTransformer(torch.nn.Module):
+     def __init__(
+         self,
+         model_name: str,
+     ) -> None:
+         super().__init__()
+         self.model_name = model_name
+
+         from transformers import SwinConfig, SwinModel
+
+         self.config = SwinConfig.from_pretrained(model_name)
+         self.model = SwinModel(self.config)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         images: Tensor,
+         output_device: Optional[Union[torch.device, str]] = None,
+     ) -> Tensor:
+         return self.model(images).last_hidden_state.to(output_device)
+
+     @property
+     def device(self) -> torch.device:
+         return next(iter(self.model.parameters())).device
+
+     def __repr__(self) -> str:
+         return f'{self.__class__.__name__}(model_name={self.model_name})'
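
The two torch_geometric.nn.nlp additions are thin helpers used by GITMol above: SentenceTransformer.get_input_ids tokenizes raw strings into padded input_ids and attention_mask tensors, and VisionTransformer wraps a Hugging Face Swin backbone whose forward pass returns patch-level hidden states. A small sketch of how they compose (not part of the diff; the model names are the ones used in git_mol.py above):

    import torch
    from torch_geometric.nn.nlp import SentenceTransformer, VisionTransformer

    text_encoder = SentenceTransformer(model_name='allenai/scibert_scivocab_uncased')
    input_ids, attention_mask = text_encoder.get_input_ids(['an example caption'])

    vision_encoder = VisionTransformer(model_name='microsoft/swin-base-patch4-window7-224')
    patch_feats = vision_encoder(torch.randn(1, 3, 224, 224))
    # patch_feats has shape [1, num_patches, 1024] for the Swin-base config,
    # which matches the Linear(1024, 768) vision projection in GITMol.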