pyg-nightly 2.7.0.dev20241124__py3-none-any.whl → 2.7.0.dev20241125__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20241124.dist-info → pyg_nightly-2.7.0.dev20241125.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20241124.dist-info → pyg_nightly-2.7.0.dev20241125.dist-info}/RECORD +11 -8
- torch_geometric/__init__.py +1 -1
- torch_geometric/datasets/__init__.py +2 -0
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/nn/models/__init__.py +2 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/nlp/__init__.py +2 -0
- torch_geometric/nn/nlp/sentence_transformer.py +30 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- {pyg_nightly-2.7.0.dev20241124.dist-info → pyg_nightly-2.7.0.dev20241125.dist-info}/WHEEL +0 -0
{pyg_nightly-2.7.0.dev20241124.dist-info → pyg_nightly-2.7.0.dev20241125.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: pyg-nightly
|
3
|
-
Version: 2.7.0.
|
3
|
+
Version: 2.7.0.dev20241125
|
4
4
|
Summary: Graph Neural Network Library for PyTorch
|
5
5
|
Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
|
6
6
|
Author-email: Matthias Fey <matthias@pyg.org>
|
@@ -1,4 +1,4 @@
|
|
1
|
-
torch_geometric/__init__.py,sha256=
|
1
|
+
torch_geometric/__init__.py,sha256=hPMlzqznHr3x2xBZhYBnmC-i7KVOX-tIpw1gy43En6g,1904
|
2
2
|
torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
|
3
3
|
torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
|
4
4
|
torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
|
@@ -53,7 +53,7 @@ torch_geometric/data/temporal.py,sha256=WOJ6gFrTLikaLhUvotyUF5ql14FkE5Ox3hNkdSp6
|
|
53
53
|
torch_geometric/data/view.py,sha256=XjkVSc-UWZFCT4DlXLShZtO8duhFQkS9gq88zZXANsk,1089
|
54
54
|
torch_geometric/data/lightning/__init__.py,sha256=w3En1tJfy3kSqe1MycpOyZpHFO3fxBCgNCUOznPA3YU,178
|
55
55
|
torch_geometric/data/lightning/datamodule.py,sha256=Bn9iaIfE4NWDDWWMqCvBeZ4bIW1Silx_Ol5CPJCliaQ,29242
|
56
|
-
torch_geometric/datasets/__init__.py,sha256=
|
56
|
+
torch_geometric/datasets/__init__.py,sha256=HYgogFHWZabd5yLfc1E4eHy9QsY6ILFRPTgfOorNwWQ,6077
|
57
57
|
torch_geometric/datasets/actor.py,sha256=oUxgJIX8bi5hJr1etWNYIFyVQNDDXi1nyVpHGGMEAGQ,4304
|
58
58
|
torch_geometric/datasets/airfrans.py,sha256=212gYsk7PvF-qcmvM2YXaOBhFrS79evAGg_sPHXih4w,5439
|
59
59
|
torch_geometric/datasets/airports.py,sha256=b3gkv3gY2JkUpmGiz36Z-g7EcnSfU8lBG1YsCOWdJ6k,3758
|
@@ -92,6 +92,7 @@ torch_geometric/datasets/gdelt_lite.py,sha256=zE1WagpgmsQARQhEgdCBtALRKyuQvIZqxT
|
|
92
92
|
torch_geometric/datasets/ged_dataset.py,sha256=dtd-C6pCygNHLXgVfg3ZTWtTVHKT13Q3GlGrze1_rpo,9551
|
93
93
|
torch_geometric/datasets/gemsec.py,sha256=oMTSryTgyed9z_4ydg3ql12KM-_35uqL1AoNls5nG8M,2820
|
94
94
|
torch_geometric/datasets/geometry.py,sha256=-BxUMirZcUOf01c3avvF0b6wGPn-4S3Zj3Oau1RaJVk,4223
|
95
|
+
torch_geometric/datasets/git_mol_dataset.py,sha256=fdE7hG_gF9bNGHaUITkEnHsZPf9FZy6F66SvvXJ5Tgc,10713
|
95
96
|
torch_geometric/datasets/github.py,sha256=Qhqhkvi6eZ8VF_HqP1rL2iYToZavFNsQh7J1WdeM9dA,2687
|
96
97
|
torch_geometric/datasets/gnn_benchmark_dataset.py,sha256=4P8n7czF-gf1egLYlAcSSvfB0GXIKpAbH5UjsuFld1M,6976
|
97
98
|
torch_geometric/datasets/heterophilous_graph_dataset.py,sha256=yHHtwl4uPrid0vPOxvPV3sIS8HWdswar8FJ0h0OQ9is,4224
|
@@ -420,7 +421,7 @@ torch_geometric/nn/kge/distmult.py,sha256=dGQ0bVzjreZgFN1lXE23_IIidsiOq7ehPrMb-N
|
|
420
421
|
torch_geometric/nn/kge/loader.py,sha256=5Uc1j3OUMQnBYSHDqL7pLCty1siFLzoPkztigYO2zP8,771
|
421
422
|
torch_geometric/nn/kge/rotate.py,sha256=XLuO1AbyTt5cJxr97ZzoyAyIEsHKesgW5TvDmnGJAao,3208
|
422
423
|
torch_geometric/nn/kge/transe.py,sha256=jlejq5BLMm-sb1wWcLDp7pZqCdelWBgjDIC8ctbjSdU,3088
|
423
|
-
torch_geometric/nn/models/__init__.py,sha256=
|
424
|
+
torch_geometric/nn/models/__init__.py,sha256=vWMKzGBVxA1Fm0uGDLnH4jzYgfhK34CQTRJ-xi5pf5k,2150
|
424
425
|
torch_geometric/nn/models/attentive_fp.py,sha256=tkgvw28wg9-JqHIfBllfCwTHrZIUiv85yZJcDqjz3z0,6634
|
425
426
|
torch_geometric/nn/models/autoencoder.py,sha256=nGje-zty78Y3hxOJ9o0_6QziJjOvBlknk6z0_fDQwQU,10770
|
426
427
|
torch_geometric/nn/models/basic_gnn.py,sha256=PGa0RUMyvrNy_5yRI2jX_zwPsmZXwOQWfsWvxOiHsSk,31225
|
@@ -431,6 +432,7 @@ torch_geometric/nn/models/deepgcn.py,sha256=tIgT03cj8MghYlxEozpoGvGG_CwpJrGDxv1Z
|
|
431
432
|
torch_geometric/nn/models/dimenet.py,sha256=Kc5p-rB5q-0e8lY22l-OdQTscTxJh2lTEpeRFMdL4RY,36186
|
432
433
|
torch_geometric/nn/models/dimenet_utils.py,sha256=Eyn_EiJqwKvuYj6BtRpSxrzMG3v4Gk98X9MxZ7uvwm4,5069
|
433
434
|
torch_geometric/nn/models/g_retriever.py,sha256=VueRImNJlh1WvRWcsSXliSw8RlxlzWlu2WSFs_VQaJc,7749
|
435
|
+
torch_geometric/nn/models/git_mol.py,sha256=Wc6Hx6RDDR7sDWRWHfA5eK9e9gFsrTZ9OLmpMfoj3pE,12676
|
434
436
|
torch_geometric/nn/models/glem.py,sha256=gqQF4jlU7U_u5-zGeJZuHiEqhSXa-wLU5TghN4u5fYY,16389
|
435
437
|
torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
|
436
438
|
torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
|
@@ -454,9 +456,10 @@ torch_geometric/nn/models/schnet.py,sha256=0aaHrVtxApdvn3RHCGLQJW1MbIb--CSYUrx9O
|
|
454
456
|
torch_geometric/nn/models/signed_gcn.py,sha256=J40CnedFIqtKI1LhW1ITSEFRbA_XiJZL6lASrKwUEAI,9841
|
455
457
|
torch_geometric/nn/models/tgn.py,sha256=kEGdfLJybkbMT4UMoAh2nCzfX3_nDjfm1cicuPHEwAM,11878
|
456
458
|
torch_geometric/nn/models/visnet.py,sha256=97OFMCsPDEI5BCSi7RhoRcU2CNRp7zck2tEzrltFZj4,43192
|
457
|
-
torch_geometric/nn/nlp/__init__.py,sha256=
|
459
|
+
torch_geometric/nn/nlp/__init__.py,sha256=q6CPUiJHcc9bXw90lyj-ID4F3kfW8uPM-SOxW9uCMHs,213
|
458
460
|
torch_geometric/nn/nlp/llm.py,sha256=M15Qn0yHyA6HL2rHCH2p4H6hKjUvLfnzlxdfEFvRxSA,11732
|
459
|
-
torch_geometric/nn/nlp/sentence_transformer.py,sha256=
|
461
|
+
torch_geometric/nn/nlp/sentence_transformer.py,sha256=q5M7SGtrUzoSiNhKCGFb7JatWiukdhNF6zdq2yiqxwE,4475
|
462
|
+
torch_geometric/nn/nlp/vision_transformer.py,sha256=diVBefjIynzYs8WBlcpTeSVnw1PUecHY--B9Yd-W2hA,863
|
460
463
|
torch_geometric/nn/norm/__init__.py,sha256=u2qIDrkbeuObGVXSAIftAlvSd6ouGTtxznCfD-59UiA,669
|
461
464
|
torch_geometric/nn/norm/batch_norm.py,sha256=sJKrinHGwA-noIgteg1RD2W06rd0zskD-rXuY-36glY,8283
|
462
465
|
torch_geometric/nn/norm/diff_group_norm.py,sha256=b57XvNekrUYGDjNJlGeqvaMGNJmHwopSF0_yyBWlLuA,4722
|
@@ -623,6 +626,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
|
|
623
626
|
torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
|
624
627
|
torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
|
625
628
|
torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
|
626
|
-
pyg_nightly-2.7.0.
|
627
|
-
pyg_nightly-2.7.0.
|
628
|
-
pyg_nightly-2.7.0.
|
629
|
+
pyg_nightly-2.7.0.dev20241125.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
|
630
|
+
pyg_nightly-2.7.0.dev20241125.dist-info/METADATA,sha256=bDgjxvVn0QZLKMZH40NUhX3W96-XohGqDUXoYJ8Ly3A,62979
|
631
|
+
pyg_nightly-2.7.0.dev20241125.dist-info/RECORD,,
|
torch_geometric/__init__.py
CHANGED
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
|
|
30
30
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
31
31
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
32
32
|
|
33
|
-
__version__ = '2.7.0.
|
33
|
+
__version__ = '2.7.0.dev20241125'
|
34
34
|
|
35
35
|
__all__ = [
|
36
36
|
'Index',
|
@@ -77,6 +77,7 @@ from .myket import MyketDataset
|
|
77
77
|
from .brca_tgca import BrcaTcga
|
78
78
|
from .neurograph import NeuroGraphDataset
|
79
79
|
from .web_qsp_dataset import WebQSPDataset
|
80
|
+
from .git_mol_dataset import GitMolDataset
|
80
81
|
from .molecule_gpt_dataset import MoleculeGPTDataset
|
81
82
|
from .tag_dataset import TAGDataset
|
82
83
|
|
@@ -192,6 +193,7 @@ homo_datasets = [
|
|
192
193
|
'BrcaTcga',
|
193
194
|
'NeuroGraphDataset',
|
194
195
|
'WebQSPDataset',
|
196
|
+
'GitMolDataset',
|
195
197
|
'MoleculeGPTDataset',
|
196
198
|
'TAGDataset',
|
197
199
|
]
|
@@ -0,0 +1,263 @@
|
|
1
|
+
import sys
|
2
|
+
from typing import Any, Callable, Dict, List, Optional
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
import torch
|
6
|
+
from tqdm import tqdm
|
7
|
+
|
8
|
+
from torch_geometric.data import (
|
9
|
+
Data,
|
10
|
+
InMemoryDataset,
|
11
|
+
download_google_url,
|
12
|
+
extract_zip,
|
13
|
+
)
|
14
|
+
from torch_geometric.io import fs
|
15
|
+
|
16
|
+
|
17
|
+
def safe_index(lst: List[Any], e: Any) -> int:
    r"""Returns the position of :obj:`e` in the vocabulary :obj:`lst`.

    Values that do not occur in :obj:`lst` are mapped to its last index.
    Most feature vocabularies in this module end with a catch-all
    :obj:`'misc'` entry for exactly this purpose.

    Args:
        lst (List[Any]): The vocabulary to search.
        e (Any): The value to look up (compared with :obj:`==`). This may be
            an integer, a boolean or an RDKit enum member.

    Returns:
        int: The index of the first element equal to :obj:`e`, or
        :obj:`len(lst) - 1` if there is none.
    """
    try:
        return lst.index(e)
    except ValueError:
        # Unknown value: fall back to the trailing catch-all slot.
        return len(lst) - 1
|
19
|
+
|
20
|
+
|
21
|
+
class GitMolDataset(InMemoryDataset):
    r"""The dataset from the `"GIT-Mol: A Multi-modal Large Language Model
    for Molecular Science with Graph, Image, and Text"
    <https://arxiv.org/pdf/2308.06911>`_ paper.

    When processed from the raw data (requires :obj:`rdkit`), every example
    is a :class:`~torch_geometric.data.Data` object holding:

    * :obj:`x`: integer atom features (9 columns),
    * :obj:`edge_index` and :obj:`edge_attr`: bond connectivity (both
      directions) and integer bond features (3 columns),
    * :obj:`smiles`: the SMILES string,
    * :obj:`image`: the transformed molecule image,
    * :obj:`caption`: the text description.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
        split (int, optional): Datasets split, train/valid/test=0/1/2.
            (default: :obj:`0`)

    Raises:
        ValueError: If :obj:`split` is not :obj:`0`, :obj:`1` or :obj:`2`.
    """

    raw_url_id = '1loBXabD6ncAFY-vanRsVtRUSFkEtBweg'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
        split: int = 0,
    ):
        from torchvision import transforms

        # Validate early: `split` indexes the raw/processed file lists.
        if split not in (0, 1, 2):
            raise ValueError(f"Invalid 'split' argument (got {split}); "
                             f"expected 0 (train), 1 (valid) or 2 (test)")
        # Must be set before `super().__init__`, which reads
        # `processed_file_names` (and thus `self.split`).
        self.split = split

        # Random augmentation is only used for the training split. All splits
        # share resizing and normalization with the standard ImageNet
        # mean/std.
        augmentation = []
        if self.split == 0:
            augmentation = [
                transforms.RandomRotation(15),
                transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
            ]
        self.img_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            *augmentation,
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)

        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return ['train_3500.pkl', 'valid_450.pkl', 'test_450.pkl']

    @property
    def processed_file_names(self) -> str:
        return ['train.pt', 'valid.pt', 'test.pt'][self.split]

    def download(self) -> None:
        file_path = download_google_url(
            self.raw_url_id,
            self.raw_dir,
            'gitmol.zip',
        )
        extract_zip(file_path, self.raw_dir)

    def process(self) -> None:
        import pandas as pd
        from PIL import Image

        try:
            from rdkit import Chem, RDLogger
            RDLogger.DisableLog('rdApp.*')  # type: ignore
            WITH_RDKIT = True

        except ImportError:
            WITH_RDKIT = False

        if not WITH_RDKIT:
            print(("Using a pre-processed version of the dataset. Please "
                   "install 'rdkit' to alternatively process the raw data."),
                  file=sys.stderr)

            # Load the pre-processed file that belongs to the requested split.
            data_list = fs.torch_load(self.raw_paths[self.split])
            data_list = [Data(**data_dict) for data_dict in data_list]

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            self.save(data_list, self.processed_paths[0])
            return

        # Feature vocabularies. Values missing from a vocabulary are mapped
        # to its last entry by `safe_index`.
        allowable_features: Dict[str, List[Any]] = {
            'possible_atomic_num_list':
            list(range(1, 119)) + ['misc'],
            'possible_formal_charge_list':
            [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'],
            'possible_chirality_list': [
                Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
                Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
                Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
                Chem.rdchem.ChiralType.CHI_OTHER
            ],
            'possible_hybridization_list': [
                Chem.rdchem.HybridizationType.SP,
                Chem.rdchem.HybridizationType.SP2,
                Chem.rdchem.HybridizationType.SP3,
                Chem.rdchem.HybridizationType.SP3D,
                Chem.rdchem.HybridizationType.SP3D2,
                Chem.rdchem.HybridizationType.UNSPECIFIED, 'misc'
            ],
            'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
            'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6],
            'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
            'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'],
            'possible_is_aromatic_list': [False, True],
            'possible_is_in_ring_list': [False, True],
            'possible_bond_type_list': [
                Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC,
                Chem.rdchem.BondType.ZERO
            ],
            'possible_bond_dirs': [  # only for double bond stereo information
                Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT,
                Chem.rdchem.BondDir.ENDDOWNRIGHT
            ],
            'possible_bond_stereo_list': [
                Chem.rdchem.BondStereo.STEREONONE,
                Chem.rdchem.BondStereo.STEREOZ,
                Chem.rdchem.BondStereo.STEREOE,
                Chem.rdchem.BondStereo.STEREOCIS,
                Chem.rdchem.BondStereo.STEREOTRANS,
                Chem.rdchem.BondStereo.STEREOANY,
            ],
            'possible_is_conjugated_list': [False, True]
        }
        # Number of per-bond features appended to `edge_features_list`.
        num_edge_features = 3

        df = pd.read_pickle(
            f'{self.raw_dir}/igcdata_toy/{self.raw_file_names[self.split]}')

        data_list = []
        for _, r in tqdm(df.iterrows(), total=df.shape[0]):
            smiles = r['isosmiles']
            mol = Chem.MolFromSmiles(smiles.strip('\n'))
            # Skip SMILES that RDKit cannot parse and atom-less molecules.
            if mol is None or mol.GetNumAtoms() == 0:
                continue

            # text
            summary = r['summary']

            # image: close the file handle once the image has been decoded.
            cid = r['cid']
            img_file = f'{self.raw_dir}/igcdata_toy/imgs/CID_{cid}.png'
            with Image.open(img_file) as raw_img:
                img = self.img_transform(raw_img.convert('RGB')).unsqueeze(0)

            # graph: node features
            atom_features_list = []
            for atom in mol.GetAtoms():  # type: ignore
                atom_features_list.append([
                    safe_index(
                        allowable_features['possible_atomic_num_list'],
                        atom.GetAtomicNum()),
                    # Chirality tags outside the vocabulary map to the last
                    # entry (CHI_OTHER) instead of raising.
                    safe_index(allowable_features['possible_chirality_list'],
                               atom.GetChiralTag()),
                    safe_index(allowable_features['possible_degree_list'],
                               atom.GetTotalDegree()),
                    safe_index(
                        allowable_features['possible_formal_charge_list'],
                        atom.GetFormalCharge()),
                    safe_index(allowable_features['possible_numH_list'],
                               atom.GetTotalNumHs()),
                    safe_index(
                        allowable_features['possible_number_radical_e_list'],
                        atom.GetNumRadicalElectrons()),
                    safe_index(
                        allowable_features['possible_hybridization_list'],
                        atom.GetHybridization()),
                    allowable_features['possible_is_aromatic_list'].index(
                        atom.GetIsAromatic()),
                    allowable_features['possible_is_in_ring_list'].index(
                        atom.IsInRing()),
                ])
            x = torch.tensor(atom_features_list, dtype=torch.long)

            # graph: edges, stored once per direction with shared features
            edges_list = []
            edge_features_list = []
            for bond in mol.GetBonds():  # type: ignore
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                edge_feature = [
                    safe_index(allowable_features['possible_bond_type_list'],
                               bond.GetBondType()),
                    # Stereo values outside the vocabulary map to the last
                    # entry (STEREOANY) instead of raising.
                    safe_index(
                        allowable_features['possible_bond_stereo_list'],
                        bond.GetStereo()),
                    allowable_features['possible_is_conjugated_list'].index(
                        bond.GetIsConjugated()),
                ]
                edges_list.append((i, j))
                edge_features_list.append(edge_feature)
                edges_list.append((j, i))
                edge_features_list.append(edge_feature)

            if edges_list:
                edge_index = torch.tensor(edges_list,
                                          dtype=torch.long).t().contiguous()
                edge_attr = torch.tensor(edge_features_list, dtype=torch.long)
            else:
                # Bond-less molecules (e.g. a single ion) still need
                # well-formed `[2, 0]` and `[0, num_edge_features]` tensors.
                edge_index = torch.empty((2, 0), dtype=torch.long)
                edge_attr = torch.empty((0, num_edge_features),
                                        dtype=torch.long)

            data = Data(
                x=x,
                edge_index=edge_index,
                smiles=smiles,
                edge_attr=edge_attr,
                image=img,
                caption=summary,
            )

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        self.save(data_list, self.processed_paths[0])
|
@@ -29,6 +29,7 @@ from .pmlp import PMLP
|
|
29
29
|
from .neural_fingerprint import NeuralFingerprint
|
30
30
|
from .visnet import ViSNet
|
31
31
|
from .g_retriever import GRetriever
|
32
|
+
from .git_mol import GITMol
|
32
33
|
from .molecule_gpt import MoleculeGPT
|
33
34
|
from .glem import GLEM
|
34
35
|
# Deprecated:
|
@@ -78,6 +79,7 @@ __all__ = classes = [
|
|
78
79
|
'NeuralFingerprint',
|
79
80
|
'ViSNet',
|
80
81
|
'GRetriever',
|
82
|
+
'GITMol',
|
81
83
|
'MoleculeGPT',
|
82
84
|
'GLEM',
|
83
85
|
]
|
@@ -0,0 +1,336 @@
|
|
1
|
+
from typing import List, Optional, Tuple
|
2
|
+
|
3
|
+
import torch
|
4
|
+
import torch.nn.functional as F
|
5
|
+
from torch import Tensor
|
6
|
+
from torch.nn import BatchNorm1d, LayerNorm, Linear, ReLU, Sequential
|
7
|
+
|
8
|
+
from torch_geometric.nn import GINEConv
|
9
|
+
from torch_geometric.nn.nlp import SentenceTransformer, VisionTransformer
|
10
|
+
from torch_geometric.utils import add_self_loops, to_dense_batch
|
11
|
+
|
12
|
+
|
13
|
+
class GraphEncoder(torch.nn.Module):
    r"""GINE-based molecular graph encoder used by :class:`GITMol`.

    Integer node and edge features are embedded, then passed through
    :obj:`num_layers` :class:`~torch_geometric.nn.conv.GINEConv` layers,
    each followed by batch normalization. The node embeddings are returned
    as a dense, padded batch.

    Args:
        num_layers (int): Number of GINE layers.
        in_channels (int): Size of the embeddings and hidden node features.
        dropout (float, optional): Dropout probability applied after every
            layer. (default: :obj:`0.`)
        num_atom_type (int, optional): Size of the embedding table indexed by
            node feature column 0. (default: :obj:`120`)
        num_chirality_tag (int, optional): Size of the embedding table
            indexed by node feature column 1. (default: :obj:`3`)
        num_bond_type (int, optional): Size of the embedding table indexed
            by edge feature column 0. (default: :obj:`6`)
        num_bond_direction (int, optional): Size of the embedding table
            indexed by edge feature column 1. (default: :obj:`3`)

    .. note::
        NOTE(review): :class:`~torch_geometric.datasets.GitMolDataset`
        appears to emit chirality indices up to :obj:`3` (node column 1) and
        bond-stereo indices up to :obj:`5` (edge column 1). Both exceed the
        default table sizes above; confirm the sizes before training on that
        data.
    """
    def __init__(
        self,
        num_layers: int,
        in_channels: int,
        dropout: float = 0.,
        num_atom_type: int = 120,
        num_chirality_tag: int = 3,
        num_bond_type: int = 6,
        num_bond_direction: int = 3,
    ) -> None:
        super().__init__()

        self.num_layers = num_layers
        self.dropout = dropout

        self.x_embed1 = torch.nn.Embedding(num_atom_type, in_channels)
        self.x_embed2 = torch.nn.Embedding(num_chirality_tag, in_channels)
        self.edge_embed1 = torch.nn.Embedding(num_bond_type, in_channels)
        self.edge_embed2 = torch.nn.Embedding(num_bond_direction, in_channels)

        self.gnns = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.gnns.append(
                GINEConv(
                    nn=Sequential(
                        Linear(in_channels, in_channels * 2),
                        ReLU(),
                        Linear(in_channels * 2, in_channels),
                    ),
                    train_eps=True,
                    edge_dim=in_channels,
                ))
            self.batch_norms.append(BatchNorm1d(in_channels))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""Resets all learnable parameters of the module."""
        for embedding in (self.x_embed1, self.x_embed2, self.edge_embed1,
                          self.edge_embed2):
            torch.nn.init.xavier_uniform_(embedding.weight.data)
        # Also reset the message-passing layers and their normalizations so
        # that a later call fully re-initializes the encoder.
        for gnn, bn in zip(self.gnns, self.batch_norms):
            gnn.reset_parameters()
            bn.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        edge_attr: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor]:
        r"""Encodes a batch of graphs.

        Args:
            x (torch.Tensor): Integer node features; columns 0 and 1 are used.
            edge_index (torch.Tensor): The edge indices.
            batch (torch.Tensor): The batch vector assigning each node to a
                graph.
            edge_attr (torch.Tensor, optional): Integer edge features; columns
                0 and 1 are used. If :obj:`None`, index :obj:`0` is used for
                both columns on every edge.

        Returns:
            (torch.Tensor, torch.Tensor): The dense node embeddings of shape
            :obj:`[num_graphs, max_num_nodes, in_channels]` and the boolean
            mask of valid (non-padded) node slots, as returned by
            :meth:`~torch_geometric.utils.to_dense_batch`.
        """
        x = self.x_embed1(x[:, 0].long()) + self.x_embed2(x[:, 1].long())

        if edge_attr is None:
            edge_attr = edge_index.new_zeros((edge_index.size(1), 2))

        # Self-loops get index 0 in both edge embedding tables.
        edge_index, edge_attr = add_self_loops(
            edge_index,
            edge_attr,
            fill_value=0,
            num_nodes=x.size(0),
        )
        edge_attr = edge_attr.long()
        edge_attr = self.edge_embed1(edge_attr[:, 0]) + self.edge_embed2(
            edge_attr[:, 1])

        for i, (gnn, bn) in enumerate(zip(self.gnns, self.batch_norms)):
            x = gnn(x, edge_index, edge_attr)
            x = bn(x)
            if i < self.num_layers - 1:  # No ReLU after the final layer.
                x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)

        x, mask = to_dense_batch(x, batch)
        return x, mask
|
81
|
+
|
82
|
+
|
83
|
+
class GITFormer(torch.nn.Module):
    r"""Query transformer shared across modalities by :class:`GITMol`.

    Loads the :obj:`"allenai/scibert_scivocab_uncased"` checkpoint through
    :class:`transformers.AutoModel`, configured with
    :obj:`is_decoder=True` and :obj:`add_cross_attention=True`. A set of
    learnable query tokens can then attend to encoder hidden states of
    another modality.

    Args:
        num_query_token (int): Number of learnable query tokens.
        vision_graph_width (int): Stored on the model config as
            :obj:`encoder_width`.
        cross_attention_freq (int, optional): Stored on the model config as
            :obj:`cross_attention_freq`. (default: :obj:`2`)

    .. note::
        :obj:`encoder_width`, :obj:`cross_attention_freq` and
        :obj:`query_length` are custom config attributes.
        NOTE(review): a stock Hugging Face BERT model likely ignores them and
        adds cross-attention to *every* layer when
        :obj:`add_cross_attention=True`; confirm before relying on
        :obj:`cross_attention_freq`.
    """
    def __init__(
        self,
        num_query_token: int,
        vision_graph_width: int,
        cross_attention_freq: int = 2,
    ):
        super().__init__()
        from transformers import AutoConfig, AutoModel

        checkpoint = "allenai/scibert_scivocab_uncased"

        qformer_config = AutoConfig.from_pretrained(checkpoint)
        qformer_config.encoder_width = vision_graph_width
        qformer_config.add_cross_attention = True
        qformer_config.is_decoder = True
        qformer_config.cross_attention_freq = cross_attention_freq
        qformer_config.query_length = num_query_token

        self.Qformer = AutoModel.from_pretrained(checkpoint,
                                                 config=qformer_config)

        # Queries are drawn from N(0, initializer_range^2), shape
        # [1, num_query_token, hidden_size].
        initial_queries = torch.zeros(1, num_query_token,
                                      qformer_config.hidden_size)
        initial_queries.normal_(mean=0.0,
                                std=qformer_config.initializer_range)
        self.query_tokens = torch.nn.Parameter(initial_queries)
|
105
|
+
|
106
|
+
|
107
|
+
class GITMol(torch.nn.Module):
    r"""The GITMol model from the `"GIT-Mol: A Multi-modal Large Language
    Model for Molecular Science with Graph, Image, and Text"
    <https://arxiv.org/pdf/2308.06911>`_ paper.

    The model encodes each molecule's graph, SMILES string and image. Each of
    these modalities is aligned with the molecule's caption through a shared
    query transformer (:class:`GITFormer`), using a contrastive (XTC) and a
    matching (XTM) objective. The graph and vision encoders are frozen.

    .. note::
        For an example of using :class:`GITMol`, see
        `examples/llm/git_mol.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/llm/git_mol.py>`_.
    """
    def __init__(self) -> None:
        super().__init__()
        # graph
        self.graph_encoder = GraphEncoder(num_layers=2, in_channels=16)
        self.graph_proj = Linear(16, 768)
        self.ln_graph = LayerNorm(768)
        # text
        self.text_encoder = SentenceTransformer(
            model_name='allenai/scibert_scivocab_uncased',
            pooling_strategy='last_hidden_state',
        )
        self.text_proj = Linear(768, 768)
        self.ln_text = LayerNorm(768)
        # vision
        self.vision_encoder = VisionTransformer(
            model_name='microsoft/swin-base-patch4-window7-224', )
        self.vision_proj = Linear(1024, 768)
        self.ln_vision = LayerNorm(768)
        # cross-attention
        self.gitformer = GITFormer(384, 768)

        hidden_size = self.gitformer.Qformer.config.hidden_size
        # Two-way (match / no match) classification heads for XTM.
        self.xtm_head = torch.nn.ModuleDict({
            'image': Linear(hidden_size, 2),
            'graph': Linear(hidden_size, 2),
            'cs_text': Linear(hidden_size, 2),
        })
        # Projections of the query outputs used by XTC.
        self.xtc_proj = torch.nn.ModuleDict({
            'image': Linear(hidden_size, 768),
            'graph': Linear(hidden_size, 768),
            'cs_text': Linear(hidden_size, 768),
        })
        # Learnable temperature dividing the XTC similarities.
        self.temp = torch.nn.Parameter(0.07 * torch.ones([]))
        self.model_freeze()

    def model_freeze(self) -> None:
        r"""Disables gradients for the graph and vision encoders."""
        for encoder in (self.graph_encoder, self.vision_encoder):
            for param in encoder.parameters():
                param.requires_grad = False

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        edge_attr: Optional[Tensor],
        smiles: List[str],
        images: Tensor,
        captions: List[str],
    ) -> Tensor:
        r"""Computes the GIT-Mol pre-training loss.

        Args:
            x (torch.Tensor): Integer node features of the molecular graphs.
            edge_index (torch.Tensor): The edge indices.
            batch (torch.Tensor): The batch vector assigning nodes to
                molecules.
            edge_attr (torch.Tensor, optional): Integer edge features.
            smiles (List[str]): One SMILES string per molecule.
            images (torch.Tensor): One image per molecule.
            captions (List[str]): One text description per molecule.

        Returns:
            torch.Tensor: The sum of the XTC and XTM losses of the graph,
            SMILES and image modalities, divided by six.
        """
        batch_size = len(smiles)

        x_vision = self.vision_encoder(images)
        x_vision = self.vision_proj(x_vision)
        x_vision = self.ln_vision(x_vision)  # [bs, patch_len, d]
        vision_atts = torch.ones(x_vision.size()[:-1], dtype=torch.long,
                                 device=x_vision.device)
        vision_targets = torch.arange(batch_size, device=x_vision.device)

        x_graph, graph_atts = self.graph_encoder(x, edge_index, batch,
                                                 edge_attr)
        x_graph = self.graph_proj(x_graph)
        x_graph = self.ln_graph(x_graph)  # [bs, node_len, d]
        graph_targets = torch.arange(batch_size, device=x_graph.device)

        x_smiles = self.text_encoder.encode(smiles)  # [bs, seq_len, d]
        smiles_atts = torch.ones(x_smiles.size()[:-1], dtype=torch.long,
                                 device=x_smiles.device)
        smiles_targets = torch.arange(batch_size, device=x_smiles.device)

        caption_input_ids, caption_attention_masks = \
            self.text_encoder.get_input_ids(captions)

        text_output = self.gitformer.Qformer(
            caption_input_ids,
            attention_mask=caption_attention_masks,
            return_dict=True,
        )
        # Caption feature: projection of the first token's hidden state.
        text_feat = F.normalize(
            self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)

        loss = 0
        for x_embed, x_atts, x_targets, modal in zip(
            [x_graph, x_smiles, x_vision],
            [graph_atts, smiles_atts, vision_atts],
            [graph_targets, smiles_targets, vision_targets],
            ['graph', 'cs_text', 'image'],
        ):
            loss += self._calc_xtc_loss(x_embed, x_atts, x_targets, text_feat,
                                        modal)
            # Pass the modality mask so padded graph nodes are not attended.
            loss += self._calc_xtm_loss(x_embed, caption_input_ids,
                                        caption_attention_masks, modal,
                                        x_atts=x_atts)

        return loss / 6

    def _calc_xtm_loss(
        self,
        x_embeds: Tensor,
        input_ids: Tensor,
        attention_mask: Tensor,
        modal: str,
        x_atts: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Computes the cross-modal matching (XTM) loss for one modality.

        For a batch of size :math:`B > 1`, :math:`3B` pairs are scored in
        the order [positives | text negatives | modal negatives], matching
        the label layout :obj:`[1] * B + [0] * 2B`. Each sample's negative
        partner is its neighbour in the batch: :obj:`i + 1`, or :obj:`i - 1`
        for the last sample. For :math:`B = 1` only the positive pair is
        scored.

        Args:
            x_embeds (torch.Tensor): Modality embeddings of shape
                :obj:`[B, seq_len, d]`.
            input_ids (torch.Tensor): Caption token ids, one row per sample.
            attention_mask (torch.Tensor): Caption attention masks.
            modal (str): Key into :obj:`self.xtm_head`.
            x_atts (torch.Tensor, optional): Attention mask over
                :obj:`x_embeds`. All ones if :obj:`None`.

        Returns:
            torch.Tensor: The cross-entropy of the match/no-match logits.
        """
        # NOTE(review): `input_ids` / `attention_mask` are not fed to the
        # Qformer below (only the query tokens and `x_embeds` are), so the
        # "text negative" group is indistinguishable from the positives.
        # Confirm whether the captions should enter the query transformer.
        batch_size = x_embeds.size(0)
        if x_atts is None:
            x_atts = torch.ones(x_embeds.size()[:-1], dtype=torch.long,
                                device=x_embeds.device)

        if batch_size > 1:
            neg_idx = torch.arange(1, batch_size + 1, device=x_embeds.device)
            neg_idx[-1] = batch_size - 2
            x_embeds_all = torch.cat(
                [x_embeds, x_embeds, x_embeds[neg_idx]], dim=0)
            neg_idx = neg_idx.to(x_atts.device)
            x_atts_all = torch.cat([x_atts, x_atts, x_atts[neg_idx]], dim=0)
            labels = torch.cat(
                [torch.ones(batch_size),
                 torch.zeros(batch_size * 2)], dim=0)
        else:
            x_embeds_all = x_embeds
            x_atts_all = x_atts
            labels = torch.ones(batch_size)

        query_tokens_xtm = self.gitformer.query_tokens.expand(
            x_embeds_all.size(0), -1, -1)
        query_attns_xtm = torch.ones(query_tokens_xtm.size()[:-1],
                                     dtype=torch.long,
                                     device=x_embeds_all.device)

        output_xtm = self.gitformer.Qformer(
            inputs_embeds=query_tokens_xtm,
            attention_mask=query_attns_xtm,
            encoder_hidden_states=x_embeds_all,
            encoder_attention_mask=x_atts_all,
            return_dict=True,
        ).last_hidden_state

        xtm_embeddings = output_xtm[:, :query_tokens_xtm.size(1), :]
        # Average the per-query logits into one logit pair per sample.
        xtm_logit = self.xtm_head[modal](xtm_embeddings).mean(dim=1)
        labels = labels.long().to(xtm_logit.device)

        return F.cross_entropy(xtm_logit, labels)

    def _calc_xtc_loss(
        self,
        x_embeds: Tensor,
        x_atts: Tensor,
        x_targets: Tensor,
        text_feat: Tensor,
        modal: str,
    ) -> Tensor:
        r"""Computes the cross-modal contrastive (XTC) loss for one modality.

        The query tokens attend to :obj:`x_embeds`; the similarity between a
        sample and a caption is the maximum over the query tokens. The loss
        averages the label-smoothed cross-entropy of both directions
        (modality-to-text and text-to-modality).

        Args:
            x_embeds (torch.Tensor): Modality embeddings.
            x_atts (torch.Tensor): Attention mask over :obj:`x_embeds`.
            x_targets (torch.Tensor): Index of the matching caption for each
                sample.
            text_feat (torch.Tensor): Normalized caption features of shape
                :obj:`[B, d]`.
            modal (str): Key into :obj:`self.xtc_proj`.
        """
        query_tokens = self.gitformer.query_tokens.expand(
            x_embeds.shape[0], -1, -1)

        query_output = self.gitformer.Qformer(
            inputs_embeds=query_tokens,
            encoder_hidden_states=x_embeds,
            encoder_attention_mask=x_atts,
            return_dict=True,
        ).last_hidden_state

        x_feats = F.normalize(self.xtc_proj[modal](query_output), dim=-1)

        # [B, 1, Q, d] @ [B, d, 1] broadcasts to [B, B, Q, 1].
        sim_q2t = torch.matmul(
            x_feats.unsqueeze(1),
            text_feat.unsqueeze(-1),
        ).squeeze(-1)

        # modal-text similarity: aggregate across all query tokens
        sim_x2t, _ = sim_q2t.max(-1)
        sim_x2t = sim_x2t / self.temp

        # text-query similarity
        sim_t2q = torch.matmul(
            text_feat.unsqueeze(1).unsqueeze(1),
            x_feats.permute(0, 2, 1),
        ).squeeze(-2)

        # text-modal similarity: aggregate across all query tokens
        sim_t2x, _ = sim_t2q.max(-1)
        sim_t2x = sim_t2x / self.temp

        loss_itc = (
            F.cross_entropy(sim_x2t, x_targets, label_smoothing=0.1) +
            F.cross_entropy(sim_t2x, x_targets, label_smoothing=0.1)) / 2

        return loss_itc
|
@@ -48,6 +48,36 @@ class SentenceTransformer(torch.nn.Module):
|
|
48
48
|
emb = F.normalize(emb, p=2, dim=1)
|
49
49
|
return emb
|
50
50
|
|
51
|
+
def get_input_ids(
    self,
    text: List[str],
    batch_size: Optional[int] = None,
    output_device: Optional[Union[torch.device, str]] = None,
) -> Tensor:
    r"""Tokenizes :obj:`text` and returns token ids and attention masks.

    Args:
        text (List[str]): The input strings.
        batch_size (int, optional): Number of strings passed to the tokenizer
            per call. If :obj:`None`, all strings are tokenized at once.
            (default: :obj:`None`)
        output_device (torch.device or str, optional): Device the outputs are
            moved to at the end. (default: :obj:`None`)

    Returns:
        A tuple :obj:`(input_ids, attention_mask)` with one row per input
        string. All rows are padded to the longest tokenized sequence across
        *all* chunks: ids with the tokenizer's pad token id, masks with
        :obj:`0`. For empty :obj:`text`, both tensors have zero rows.
    """
    is_empty = len(text) == 0
    # Tokenize a placeholder so that the output keeps a valid 2D layout.
    text = ['dummy'] if is_empty else text

    batch_size = len(text) if batch_size is None else batch_size

    input_ids: List[Tensor] = []
    attention_masks: List[Tensor] = []
    for start in range(0, len(text), batch_size):
        token = self.tokenizer(
            text[start:start + batch_size],
            padding=True,
            truncation=True,
            return_tensors='pt',
        )
        input_ids.append(token.input_ids.to(self.device))
        attention_masks.append(token.attention_mask.to(self.device))

    def _out(x: List[Tensor], pad_value: int) -> Tensor:
        # Each chunk is only padded to its own longest sequence, so bring all
        # chunks to a common width before concatenating them.
        max_len = max(t.size(1) for t in x)
        x = [F.pad(t, (0, max_len - t.size(1)), value=pad_value) for t in x]
        out = torch.cat(x, dim=0) if len(x) > 1 else x[0]
        out = out[:0] if is_empty else out
        return out.to(output_device)

    return (
        _out(input_ids, self.tokenizer.pad_token_id),
        _out(attention_masks, 0),
    )
|
80
|
+
|
51
81
|
@property
|
52
82
|
def device(self) -> torch.device:
|
53
83
|
return next(iter(self.model.parameters())).device
|
@@ -0,0 +1,33 @@
|
|
1
|
+
from typing import Optional, Union
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from torch import Tensor
|
5
|
+
|
6
|
+
|
7
|
+
class VisionTransformer(torch.nn.Module):
    r"""Frozen-friendly wrapper around a Hugging Face Swin Transformer image
    encoder.

    Args:
        model_name (str): Name or path of a Swin checkpoint loadable with
            :meth:`transformers.SwinModel.from_pretrained`, *e.g.*,
            :obj:`"microsoft/swin-base-patch4-window7-224"`.
    """
    def __init__(
        self,
        model_name: str,
    ) -> None:
        super().__init__()
        self.model_name = model_name

        from transformers import SwinModel

        # Load the pretrained weights, not just the architecture config:
        # `SwinModel(config)` would be randomly initialized, which is useless
        # for an encoder that callers (e.g. GITMol) freeze.
        self.model = SwinModel.from_pretrained(model_name)
        self.config = self.model.config

    @torch.no_grad()
    def forward(
        self,
        images: Tensor,
        output_device: Optional[Union[torch.device, str]] = None,
    ) -> Tensor:
        r"""Returns the Swin model's :obj:`last_hidden_state` for
        :obj:`images`.

        NOTE(review): presumably of shape
        :obj:`[batch_size, num_patches, hidden_size]`; confirm against the
        checkpoint.

        Args:
            images (torch.Tensor): Preprocessed image batch. It is moved to
                the model's device before encoding.
            output_device (torch.device or str, optional): Device the output
                is moved to. (default: :obj:`None`)
        """
        images = images.to(self.device)
        return self.model(images).last_hidden_state.to(output_device)

    @property
    def device(self) -> torch.device:
        return next(iter(self.model.parameters())).device

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(model_name={self.model_name})'
|
File without changes
|