pyg-nightly 2.7.0.dev20250124__py3-none-any.whl → 2.7.0.dev20250126__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: pyg-nightly
3
- Version: 2.7.0.dev20250124
3
+ Version: 2.7.0.dev20250126
4
4
  Summary: Graph Neural Network Library for PyTorch
5
5
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
6
6
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
1
- torch_geometric/__init__.py,sha256=woO2qQnJ5h-sALP2KdvQ69AVAzlveBh2Qs87DiEu7A4,1904
1
+ torch_geometric/__init__.py,sha256=urZ1LyKqq-2Oed-4wqdQLl23rnoUSD-BKPA6eHOOy4s,1904
2
2
  torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
3
3
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
4
4
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -54,7 +54,7 @@ torch_geometric/data/temporal.py,sha256=WOJ6gFrTLikaLhUvotyUF5ql14FkE5Ox3hNkdSp6
54
54
  torch_geometric/data/view.py,sha256=XjkVSc-UWZFCT4DlXLShZtO8duhFQkS9gq88zZXANsk,1089
55
55
  torch_geometric/data/lightning/__init__.py,sha256=w3En1tJfy3kSqe1MycpOyZpHFO3fxBCgNCUOznPA3YU,178
56
56
  torch_geometric/data/lightning/datamodule.py,sha256=Bn9iaIfE4NWDDWWMqCvBeZ4bIW1Silx_Ol5CPJCliaQ,29242
57
- torch_geometric/datasets/__init__.py,sha256=eqVmuffZnc-O7KBdXO98SNVwSGehT5uy2LAC86MxGO4,6107
57
+ torch_geometric/datasets/__init__.py,sha256=d9nuTCytBvg60lm_WYRAQwjoZxR1H_7JsW8een1k1No,6186
58
58
  torch_geometric/datasets/actor.py,sha256=oUxgJIX8bi5hJr1etWNYIFyVQNDDXi1nyVpHGGMEAGQ,4304
59
59
  torch_geometric/datasets/airfrans.py,sha256=212gYsk7PvF-qcmvM2YXaOBhFrS79evAGg_sPHXih4w,5439
60
60
  torch_geometric/datasets/airports.py,sha256=b3gkv3gY2JkUpmGiz36Z-g7EcnSfU8lBG1YsCOWdJ6k,3758
@@ -104,6 +104,7 @@ torch_geometric/datasets/icews.py,sha256=Vdlk-PD10AU68xq8X5IOgrK0wgIBFq8A0D6_Wtr
104
104
  torch_geometric/datasets/igmc_dataset.py,sha256=pMiOoXjvqhfsDDNw51WT_IVi6wGJ0cUNwTdpEprPh3E,4611
105
105
  torch_geometric/datasets/imdb.py,sha256=QVJbtPPkcLznyvzuxDCxmqO5xXocVG59KhrjXi1qXg0,4232
106
106
  torch_geometric/datasets/infection_dataset.py,sha256=jIYqX0vkCE-3fNjaijzCSmY1RVMFiX3gnmLwkqDXRkI,7293
107
+ torch_geometric/datasets/instruct_mol_dataset.py,sha256=EK_3lRflFYS6KHfPM1AcYtB7FRs2We3GgTu39H9vVKI,4990
107
108
  torch_geometric/datasets/jodie.py,sha256=8CW43ZepM26dk2HMGvXDDF-4BorBeegqegViWyeYOks,3643
108
109
  torch_geometric/datasets/karate.py,sha256=khCcCUEaw7FuYBKwEsOoogpTShKYnx5nXrRtCOAoEAU,3462
109
110
  torch_geometric/datasets/last_fm.py,sha256=jKM3gw7T5x4AlUtmA0TXB2iWpNMi-S-ME2bP37kzE3Q,4581
@@ -629,6 +630,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
629
630
  torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
630
631
  torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
631
632
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
632
- pyg_nightly-2.7.0.dev20250124.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
633
- pyg_nightly-2.7.0.dev20250124.dist-info/METADATA,sha256=Rzd4j6F9Pjy0ArOM92QFdu_jUeoFnqyDAbIlh136kzk,62977
634
- pyg_nightly-2.7.0.dev20250124.dist-info/RECORD,,
633
+ pyg_nightly-2.7.0.dev20250126.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
634
+ pyg_nightly-2.7.0.dev20250126.dist-info/METADATA,sha256=g1OMv25SQV6iB9NN9qWPzTDvigR6j9StZi-NMOWtE4U,62977
635
+ pyg_nightly-2.7.0.dev20250126.dist-info/RECORD,,
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
30
30
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
31
31
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
32
32
 
33
- __version__ = '2.7.0.dev20250124'
33
+ __version__ = '2.7.0.dev20250126'
34
34
 
35
35
  __all__ = [
36
36
  'Index',
@@ -79,6 +79,7 @@ from .neurograph import NeuroGraphDataset
79
79
  from .web_qsp_dataset import WebQSPDataset, CWQDataset
80
80
  from .git_mol_dataset import GitMolDataset
81
81
  from .molecule_gpt_dataset import MoleculeGPTDataset
82
+ from .instruct_mol_dataset import InstructMolDataset
82
83
  from .tag_dataset import TAGDataset
83
84
 
84
85
  from .dbp15k import DBP15K
@@ -196,6 +197,7 @@ homo_datasets = [
196
197
  'CWQDataset',
197
198
  'GitMolDataset',
198
199
  'MoleculeGPTDataset',
200
+ 'InstructMolDataset',
199
201
  'TAGDataset',
200
202
  ]
201
203
 
@@ -0,0 +1,134 @@
1
+ import json
2
+ import sys
3
+ from typing import Callable, List, Optional
4
+
5
+ import torch
6
+ from tqdm import tqdm
7
+
8
+ from torch_geometric.data import Data, InMemoryDataset
9
+ from torch_geometric.io import fs
10
+ from torch_geometric.utils import one_hot
11
+
12
+
13
class InstructMolDataset(InMemoryDataset):
    r"""The dataset from the `"InstructMol: Multi-Modal Integration for
    Building a Versatile and Reliable Molecular Assistant in Drug Discovery"
    <https://arxiv.org/pdf/2311.16208>`_ paper.

    Every example is one (instruction, answer) pair attached to the molecular
    graph parsed from its SMILES string. Each :obj:`Data` object holds:

    * :obj:`x`: one-hot atom types over ``H, C, N, O, F`` plus one shared
      class for any other element.
    * :obj:`edge_index` / :obj:`edge_attr`: both directions of every bond,
      with one-hot bond types (single, double, triple, aromatic).
    * :obj:`smiles`, :obj:`instruction`, :obj:`y`: the raw SMILES string, the
      question and the answer.

    .. note::
        Processing the raw data requires :obj:`rdkit`.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    # Hugging Face serves raw file contents under ``resolve``; the ``blob``
    # path is the HTML viewer page and would not yield valid JSON.
    raw_url = 'https://huggingface.co/datasets/OpenMol/PubChemSFT/resolve/main'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ):
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return ['all_clean.json']

    @property
    def processed_file_names(self) -> List[str]:
        return ['data.pt']

    def download(self) -> None:
        """Copy the raw JSON file from :obj:`raw_url` into the raw dir."""
        print('downloading dataset...')
        fs.cp(f'{self.raw_url}/all_clean.json', self.raw_dir)

    def process(self) -> None:
        """Convert the raw SMILES/QA JSON into graphs and save them.

        Raises:
            ImportError: If :obj:`rdkit` is not installed.
        """
        # The only raw file is JSON, so there is nothing to fall back to
        # without rdkit (loading it via `fs.torch_load` cannot succeed).
        try:
            from rdkit import Chem
            from rdkit.Chem.rdchem import BondType as BT
        except ImportError as e:
            raise ImportError(
                "'InstructMolDataset' requires 'rdkit' to process the raw "
                "data. This dataset does not download a pre-processed "
                "version.") from e

        # types of atom and bond
        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5}
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}
        unknown_idx = types['Unknow']

        # load data
        with open(self.raw_paths[0], encoding='utf-8') as f:
            mols = json.load(f)

        data_list = []
        for smiles, qa_pairs in tqdm(mols.items(), total=len(mols)):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:  # SMILES that rdkit cannot parse is skipped.
                continue

            # Explicit `long` dtype keeps an atom-less molecule from yielding
            # a float tensor, which is not a valid `one_hot` index.
            x: torch.Tensor = torch.tensor(
                [
                    types.get(atom.GetSymbol(), unknown_idx)
                    for atom in mol.GetAtoms()
                ],
                dtype=torch.long,
            )
            x = one_hot(x, num_classes=len(types), dtype=torch.float)

            rows, cols, edge_types = [], [], []
            for bond in mol.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                # NOTE(review): bond types outside `bonds` raise `KeyError`
                # here, as before - confirm whether they occur in the data.
                edge_types += [bonds[bond.GetBondType()]] * 2
                # Store both directions of each bond.
                rows += [i, j]
                cols += [j, i]

            edge_index = torch.tensor([rows, cols], dtype=torch.long)
            edge_type = torch.tensor(edge_types, dtype=torch.long)
            edge_attr = one_hot(edge_type, num_classes=len(bonds))

            # One example per QA pair; the graph tensors are shared between
            # all examples of the same molecule.
            for question, answer in qa_pairs:
                data = Data(
                    x=x,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    smiles=smiles,
                    instruction=question,
                    y=answer,
                )

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

        self.save(data_list, self.processed_paths[0])