pyg-nightly 2.7.0.dev20241119__py3-none-any.whl → 2.7.0.dev20241121__py3-none-any.whl

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: pyg-nightly
- Version: 2.7.0.dev20241119
+ Version: 2.7.0.dev20241121
  Summary: Graph Neural Network Library for PyTorch
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
- torch_geometric/__init__.py,sha256=spxW7Bk1ADYtDbAY5o7hc4aHzY-HMhp_JzJaHacQX30,1904
+ torch_geometric/__init__.py,sha256=KxeHpFIYYrXJ-wesw35LT5EYOhXnC86_S8hRPAfpOy4,1904
  torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -53,7 +53,7 @@ torch_geometric/data/temporal.py,sha256=WOJ6gFrTLikaLhUvotyUF5ql14FkE5Ox3hNkdSp6
  torch_geometric/data/view.py,sha256=XjkVSc-UWZFCT4DlXLShZtO8duhFQkS9gq88zZXANsk,1089
  torch_geometric/data/lightning/__init__.py,sha256=w3En1tJfy3kSqe1MycpOyZpHFO3fxBCgNCUOznPA3YU,178
  torch_geometric/data/lightning/datamodule.py,sha256=Bn9iaIfE4NWDDWWMqCvBeZ4bIW1Silx_Ol5CPJCliaQ,29242
- torch_geometric/datasets/__init__.py,sha256=fey-955PyCQXGBeUTNPWwU5uK3PJOEvaY1_fDt1SxXc,5880
+ torch_geometric/datasets/__init__.py,sha256=f9YqoX9WTSVMzjuLfFD_eCsC4iQ5kbFNQiZru3n6qw0,6013
  torch_geometric/datasets/actor.py,sha256=oUxgJIX8bi5hJr1etWNYIFyVQNDDXi1nyVpHGGMEAGQ,4304
  torch_geometric/datasets/airfrans.py,sha256=212gYsk7PvF-qcmvM2YXaOBhFrS79evAGg_sPHXih4w,5439
  torch_geometric/datasets/airports.py,sha256=b3gkv3gY2JkUpmGiz36Z-g7EcnSfU8lBG1YsCOWdJ6k,3758
@@ -113,6 +113,7 @@ torch_geometric/datasets/md17.py,sha256=BD6LU2xm6_ycXVk6r4O0poNt5Sr_PJ2P1QjNqIOL
  torch_geometric/datasets/mixhop_synthetic_dataset.py,sha256=4NNvTHUvvV6pcqQCyVDS5XhppXUeF2H9GTfFoc49eyU,3951
  torch_geometric/datasets/mnist_superpixels.py,sha256=o2ArbZ0_OE0u8VCaHmWwvngESlOFr9oM9dSEP_tjAS4,3340
  torch_geometric/datasets/modelnet.py,sha256=-qmLjlQiKVWmtHefAIIE97dQxEcaBfetMJnvgYZuwkg,5347
+ torch_geometric/datasets/molecule_gpt_dataset.py,sha256=XE14wgPVBm2kVLYL6NgXUDhv4QGHxVISG-VWEwO7hfA,18754
  torch_geometric/datasets/molecule_net.py,sha256=VNWLEDulFID8mLsxgN8q1T-O3M2i0n0Si5ISwEZezMU,7379
  torch_geometric/datasets/movie_lens.py,sha256=M4Bu0Xus8IkW8GYzjxPxSdPXNbcCCx9cu6cncxBvLx8,4033
  torch_geometric/datasets/movie_lens_100k.py,sha256=eTpBAteM3jqTEtiwLxmhVj4r8JvftvPx8Hvs-3ZIHlU,6057
@@ -144,6 +145,7 @@ torch_geometric/datasets/shapenet.py,sha256=tn3HiQQAr6lxHrqxfOVaAtl40guwFYTXWCbS
  torch_geometric/datasets/shrec2016.py,sha256=cTLhctbqE0EUEvKddJFhPzDb1oLKXOth4O_WzsWtyMk,6323
  torch_geometric/datasets/snap_dataset.py,sha256=r3sC-dHDouyaYoHGdoBY0uO0qOOvD6_Hb96d2ceGMZk,9433
  torch_geometric/datasets/suite_sparse.py,sha256=eqjH4vAUq872qdk3YdLkZSwlu6r7HHpTgK0vEVGmY1s,3278
+ torch_geometric/datasets/tag_dataset.py,sha256=0fzOsakR9L9CK6ppGN-USD4-Vq-ssbQ2Xovw2nqqtWo,14759
  torch_geometric/datasets/taobao.py,sha256=CUcZpbWsNTasevflO8zqP0YvENy89P7wpKS4MHaDJ6Q,4170
  torch_geometric/datasets/tosca.py,sha256=nUSF8NQT1GlkwWQLshjWmr8xORsvRHzzIqhUyDCvABc,4632
  torch_geometric/datasets/tu_dataset.py,sha256=14OSaXBgVwT1dX2h1wZ3xVIwoo0GQBEfR3yWh6Q0VF0,7847
@@ -324,8 +326,9 @@ torch_geometric/nn/aggr/set_transformer.py,sha256=FG7_JizpFX14M6VSCwLSjYXYdJ1ZiQ
  torch_geometric/nn/aggr/sort.py,sha256=bvOOWnFkNOBOZih4rqVZQsjfeDX3vmXo1bpPSFD846w,2507
  torch_geometric/nn/aggr/utils.py,sha256=CLJ-ZrVWYIOBpdhQBLAz94dj3cMKKKc3qwGr4DFbiCU,8338
  torch_geometric/nn/aggr/variance_preserving.py,sha256=fu-U_aGYpVLpgSFvVg0ONMe6nqoyv8tZ6Y35qMYTf9w,1126
- torch_geometric/nn/attention/__init__.py,sha256=Ip6n4xbUbhJhrmPO9LjvHq0nNQe-yxiC4WHyOYOrHJc,76
+ torch_geometric/nn/attention/__init__.py,sha256=1lCB7zh7uM6FkpW81S9U4CvxTwpCkz59KatPTIE9UmA,127
  torch_geometric/nn/attention/performer.py,sha256=2PCDn4_-oNTao2-DkXIaoi18anP01OxRELF2pvp-jk8,7357
+ torch_geometric/nn/attention/qformer.py,sha256=7J-pWm_vpumK38IC-iCBz4oqL-BEIofEIxJ0wfjWq9A,2338
  torch_geometric/nn/conv/__init__.py,sha256=37zTdt0gfSAUPMtwXjZg5mWx_itojJVFNODYR1h1ch0,3515
  torch_geometric/nn/conv/agnn_conv.py,sha256=5nEPLx_BBHcDaO6HWzLuHfXc0Yd_reKynAOH0Iq09lU,3077
  torch_geometric/nn/conv/antisymmetric_conv.py,sha256=dhA6sCETy1jlXReYJZBSyToOcL_mZ1wL10fMIb8Ppuw,4387
@@ -417,7 +420,7 @@ torch_geometric/nn/kge/distmult.py,sha256=dGQ0bVzjreZgFN1lXE23_IIidsiOq7ehPrMb-N
  torch_geometric/nn/kge/loader.py,sha256=5Uc1j3OUMQnBYSHDqL7pLCty1siFLzoPkztigYO2zP8,771
  torch_geometric/nn/kge/rotate.py,sha256=XLuO1AbyTt5cJxr97ZzoyAyIEsHKesgW5TvDmnGJAao,3208
  torch_geometric/nn/kge/transe.py,sha256=jlejq5BLMm-sb1wWcLDp7pZqCdelWBgjDIC8ctbjSdU,3088
- torch_geometric/nn/models/__init__.py,sha256=RpYFFqaYWq1BVMF3Fs-EQo-QZDdLQjIHPdkl3d2MOW4,2017
+ torch_geometric/nn/models/__init__.py,sha256=dr2-YsRzUdVBM6Ut78FB9Wbjn-kzV0gPwOlWGPdQLY4,2108
  torch_geometric/nn/models/attentive_fp.py,sha256=tkgvw28wg9-JqHIfBllfCwTHrZIUiv85yZJcDqjz3z0,6634
  torch_geometric/nn/models/autoencoder.py,sha256=nGje-zty78Y3hxOJ9o0_6QziJjOvBlknk6z0_fDQwQU,10770
  torch_geometric/nn/models/basic_gnn.py,sha256=PGa0RUMyvrNy_5yRI2jX_zwPsmZXwOQWfsWvxOiHsSk,31225
@@ -428,6 +431,7 @@ torch_geometric/nn/models/deepgcn.py,sha256=tIgT03cj8MghYlxEozpoGvGG_CwpJrGDxv1Z
  torch_geometric/nn/models/dimenet.py,sha256=Kc5p-rB5q-0e8lY22l-OdQTscTxJh2lTEpeRFMdL4RY,36186
  torch_geometric/nn/models/dimenet_utils.py,sha256=Eyn_EiJqwKvuYj6BtRpSxrzMG3v4Gk98X9MxZ7uvwm4,5069
  torch_geometric/nn/models/g_retriever.py,sha256=VueRImNJlh1WvRWcsSXliSw8RlxlzWlu2WSFs_VQaJc,7749
+ torch_geometric/nn/models/glem.py,sha256=gqQF4jlU7U_u5-zGeJZuHiEqhSXa-wLU5TghN4u5fYY,16389
  torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
  torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
  torch_geometric/nn/models/graph_unet.py,sha256=N8TSmJo8AlbZjjcame0xW_jZvMOirL5ahw6qv5Yjpbs,5586
@@ -439,6 +443,7 @@ torch_geometric/nn/models/mask_label.py,sha256=B2HcL6ZkaUEo3a8nebZoUqEIfDEfcIGOV
  torch_geometric/nn/models/meta.py,sha256=lQWovjdQgTGT_rDAm6L186ObINeQCD9tLBz8xenmrF0,6540
  torch_geometric/nn/models/metapath2vec.py,sha256=nxttGe4QVWr4teYEoNz8uHRu-yVsLSZPOeF_tz0bj2o,10788
  torch_geometric/nn/models/mlp.py,sha256=rdwUFxxxqLjXK-iy1L1sXiwSNwAfqTlvHLaqVZ-jwCs,10315
+ torch_geometric/nn/models/molecule_gpt.py,sha256=k-XULH6jaurj-R2EE4sIWTkqlNqa3CzWxfQgfFa-G8s,7637
  torch_geometric/nn/models/neural_fingerprint.py,sha256=pTLJgU9Uh2Lnf9bggLj4cKI8YdEFcMF-9MALuubqbuQ,2378
  torch_geometric/nn/models/node2vec.py,sha256=U-VhJlvt5lT-JShFrF5tN84wCPqoVuftLVNyOVXs0OU,7664
  torch_geometric/nn/models/pmlp.py,sha256=dcAASVSyQMMhItSfEJWPeAFh0R3tNCwAHwdrShwQ8o4,3538
@@ -450,8 +455,8 @@ torch_geometric/nn/models/signed_gcn.py,sha256=J40CnedFIqtKI1LhW1ITSEFRbA_XiJZL6
  torch_geometric/nn/models/tgn.py,sha256=kEGdfLJybkbMT4UMoAh2nCzfX3_nDjfm1cicuPHEwAM,11878
  torch_geometric/nn/models/visnet.py,sha256=97OFMCsPDEI5BCSi7RhoRcU2CNRp7zck2tEzrltFZj4,43192
  torch_geometric/nn/nlp/__init__.py,sha256=JJESTA7w_K8v60XbCd25IqmrKKHLz5OiNexMHYGV2mE,138
- torch_geometric/nn/nlp/llm.py,sha256=_penl2qkDMeVtlwGPrl7UuyxBh6ILtdiLHmrUNQHkYc,11731
- torch_geometric/nn/nlp/sentence_transformer.py,sha256=JrTN3W1srdkNX7qYDGB08mY5615i5nfEJSTHAdd5EuA,3260
+ torch_geometric/nn/nlp/llm.py,sha256=M15Qn0yHyA6HL2rHCH2p4H6hKjUvLfnzlxdfEFvRxSA,11732
+ torch_geometric/nn/nlp/sentence_transformer.py,sha256=VzMtNUYk6FvOVc3PdVets9_2Sb2FdQbzu9H3m6teRlI,3417
  torch_geometric/nn/norm/__init__.py,sha256=u2qIDrkbeuObGVXSAIftAlvSd6ouGTtxznCfD-59UiA,669
  torch_geometric/nn/norm/batch_norm.py,sha256=sJKrinHGwA-noIgteg1RD2W06rd0zskD-rXuY-36glY,8283
  torch_geometric/nn/norm/diff_group_norm.py,sha256=b57XvNekrUYGDjNJlGeqvaMGNJmHwopSF0_yyBWlLuA,4722
@@ -618,6 +623,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
  torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
  torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
- pyg_nightly-2.7.0.dev20241119.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
- pyg_nightly-2.7.0.dev20241119.dist-info/METADATA,sha256=3Y-GTdZXsDzzOxIrxa35EwttZ_dPAwF2jkLotBJ9ubg,62979
- pyg_nightly-2.7.0.dev20241119.dist-info/RECORD,,
+ pyg_nightly-2.7.0.dev20241121.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
+ pyg_nightly-2.7.0.dev20241121.dist-info/METADATA,sha256=GZC_8xTwtszZfiAKIcN4yjzbZCrGmYbnpl5lr98v8eg,62979
+ pyg_nightly-2.7.0.dev20241121.dist-info/RECORD,,
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')

- __version__ = '2.7.0.dev20241119'
+ __version__ = '2.7.0.dev20241121'

  __all__ = [
      'Index',
@@ -77,6 +77,8 @@ from .myket import MyketDataset
  from .brca_tgca import BrcaTcga
  from .neurograph import NeuroGraphDataset
  from .web_qsp_dataset import WebQSPDataset
+ from .molecule_gpt_dataset import MoleculeGPTDataset
+ from .tag_dataset import TAGDataset

  from .dbp15k import DBP15K
  from .aminer import AMiner
@@ -190,6 +192,8 @@ homo_datasets = [
      'BrcaTcga',
      'NeuroGraphDataset',
      'WebQSPDataset',
+     'MoleculeGPTDataset',
+     'TAGDataset',
  ]

  hetero_datasets = [
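
The two dataset classes registered above become part of the public torch_geometric.datasets namespace in this build. A minimal import sketch (only the class names are taken from the diff; nothing further is implied about their constructors):

from torch_geometric.datasets import MoleculeGPTDataset, TAGDataset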
@@ -0,0 +1,480 @@
+ import gzip
+ import json
+ import multiprocessing
+ import os
+ import sys
+ from collections import defaultdict
+ from multiprocessing import Pool
+ from typing import Callable, List, Optional, Tuple
+
+ import numpy as np
+ import requests
+ import torch
+ from tqdm import tqdm
+
+ from torch_geometric.data import Data, InMemoryDataset, download_url
+ from torch_geometric.io import fs
+ from torch_geometric.nn.nlp import LLM
+ from torch_geometric.utils import one_hot
+
+
+ def clean_up_description(description: str) -> str:
+     description = description + " "
+
+     # extra adj Pure
+     if description.startswith("Pure "):
+         description = description.replace("Pure ", "")
+     # fix typo
+     if description.startswith("Mercurycombines"):
+         description = description.replace("Mercurycombines",
+                                           "Mercury combines")
+
+     # a special case
+     description = description.replace(
+         "17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione. ",
+         "17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione is ")
+
+     # a special case
+     description = description.replace("5-Thymidylic acid. ",
+                                        "5-Thymidylic acid. is ")
+
+     # a special case
+     description = description.replace(
+         "5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. ",
+         "5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. is ")
+
+     # a special case
+     description = description.replace(
+         ("Guanosine 5'-(trihydrogen diphosphate), monoanhydride"
+          " with phosphorothioic acid. "),
+         ("Guanosine 5'-(trihydrogen diphosphate), monoanhydride"
+          " with phosphorothioic acid is "))
+
+     # a special case
+     description = description.replace("5'-Uridylic acid. ",
+                                        "5'-Uridylic acid is ")
+
+     # a special case
+     description = description.replace("5'-Adenylic acid, ",
+                                        "5'-Adenylic acid is ")
+
+     # a special case
+     description = description.replace(
+         "Uridine 5'-(tetrahydrogen triphosphate). ",
+         "Uridine 5'-(tetrahydrogen triphosphate). is ")
+
+     # a special case
+     description = description.replace("Inosine 5'-Monophosphate. ",
+                                        "Inosine 5'-Monophosphate. is ")
+
+     # a special case
+     description = description.replace("Pivaloyloxymethyl butyrate (AN-9), ",
+                                        "Pivaloyloxymethyl butyrate (AN-9) is ")
+
+     # a special case
+     description = description.replace(
+         "4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine. ",
+         "4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine is ")
+
+     # a special case
+     description = description.replace(
+         "Cardamonin (also known as Dihydroxymethoxychalcone), ",
+         "Cardamonin (also known as Dihydroxymethoxychalcone) is ")
+
+     # a special case
+     description = description.replace("Lithium has been used to treat ",
+                                        "Lithium is ")
+
+     # a special case
+     description = description.replace("4,4'-Methylenebis ",
+                                        "4,4'-Methylenebis is ")
+
+     # a special case
+     description = description.replace(
+         "2,3,7,8-Tetrachlorodibenzo-p-dioxin",
+         "2,3,7,8-Tetrachlorodibenzo-p-dioxin is ")
+
+     # a special case
+     description = description.replace("Exposure to 2,4,5-trichlorophenol ",
+                                        "2,4,5-Trichlorophenol exposure ")
+
+     index = 0
+     L = len(description)
+     if description.startswith('C.I. '):
+         start_index = len('C.I. ')
+     elif description.startswith('Nectriapyrone. D '):
+         start_index = len('Nectriapyrone. D ')
+     elif description.startswith(
+             'Salmonella enterica sv. Minnesota LPS core oligosaccharide'):
+         start_index = len(
+             'Salmonella enterica sv. Minnesota LPS core oligosaccharide')
+     else:
+         start_index = 0
+     for index in range(start_index, L - 1):
+         if index < L - 2:
+             if description[index] == '.' and description[
+                     index + 1] == ' ' and 'A' <= description[index + 2] <= 'Z':
+                 break
+         elif index == L - 2:
+             break
+
+     first_sentence = description[:index + 1]
+     return first_sentence
+
+
+ def extract_name(name_raw: str, description: str) -> Tuple[str, str, str]:
+     first_sentence = clean_up_description(description)
+
+     splitter = ' -- -- '
+     if ' are ' in first_sentence or ' were ' in first_sentence:
+         replaced_words = 'These molecules'
+     else:
+         replaced_words = 'This molecule'
+
+     first_sentence = first_sentence.replace(' is ', splitter)
+     first_sentence = first_sentence.replace(' are ', splitter)
+     first_sentence = first_sentence.replace(' was ', splitter)
+     first_sentence = first_sentence.replace(' were ', splitter)
+     first_sentence = first_sentence.replace(' appears ', splitter)
+     first_sentence = first_sentence.replace(' occurs ', splitter)
+     first_sentence = first_sentence.replace(' stands for ', splitter)
+     first_sentence = first_sentence.replace(' belongs to ', splitter)
+     first_sentence = first_sentence.replace(' exists ',
+                                             splitter)  # only for CID=11443
+     first_sentence = first_sentence.replace(' has been used in trials ',
+                                             splitter)
+     first_sentence = first_sentence.replace(' has been investigated ',
+                                             splitter)
+     first_sentence = first_sentence.replace(' has many uses ', splitter)
+
+     if splitter in first_sentence:
+         extracted_name = first_sentence.split(splitter, 1)[0]
+     elif first_sentence.startswith(name_raw):
+         extracted_name = name_raw
+     elif name_raw in first_sentence:
+         extracted_name = name_raw
+         extracted_name = None
+         print("=====", name_raw)
+         print("first sentence: ", first_sentence)
+     else:
+         extracted_name = None
+
+     if extracted_name is not None:
+         extracted_description = description.replace(extracted_name,
+                                                      replaced_words)
+     else:
+         extracted_description = description
+
+     return extracted_name, extracted_description, first_sentence
+
+
+ class MoleculeGPTDataset(InMemoryDataset):
+     r"""The dataset from the `"MoleculeGPT: Instruction Following Large
+     Language Models for Molecular Property Prediction"
+     <https://ai4d3.github.io/papers/34.pdf>`_ paper.
+
+     Args:
+         root (str): Root directory where the dataset should be saved.
+         transform (callable, optional): A function/transform that takes in an
+             :obj:`torch_geometric.data.Data` object and returns a transformed
+             version. The data object will be transformed before every access.
+             (default: :obj:`None`)
+         pre_transform (callable, optional): A function/transform that takes in
+             an :obj:`torch_geometric.data.Data` object and returns a
+             transformed version. The data object will be transformed before
+             being saved to disk. (default: :obj:`None`)
+         pre_filter (callable, optional): A function that takes in an
+             :obj:`torch_geometric.data.Data` object and returns a boolean
+             value, indicating whether the data object should be included in the
+             final dataset. (default: :obj:`None`)
+         force_reload (bool, optional): Whether to re-process the dataset.
+             (default: :obj:`False`)
+         total_page_num (int, optional): The number of pages from PubChem.
+             (default: :obj:`10`)
+         total_block_num (int, optional): The blocks of SDF files from PubChem.
+             (default: :obj:`1`)
+     """
+     description_url = (
+         'https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/annotations/'
+         'heading/json?heading_type=Compound&heading=Record+Description&page={}'
+     )
+     compound_url = ('https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/'
+                     'CURRENT-Full/SDF')
+
+     def __init__(
+         self,
+         root: str,
+         transform: Optional[Callable] = None,
+         pre_transform: Optional[Callable] = None,
+         pre_filter: Optional[Callable] = None,
+         force_reload: bool = False,
+         total_page_num: int = 10,
+         total_block_num: int = 1,
+     ):
+         self.total_page_num = total_page_num
+         self.total_block_num = total_block_num
+
+         super().__init__(root, transform, pre_transform, pre_filter,
+                          force_reload=force_reload)
+         self.load(self.processed_paths[0])
+
+     @property
+     def raw_file_names(self) -> List[str]:
+         return ['pubchem.csv']
+
+     @property
+     def processed_file_names(self) -> List[str]:
+         return ['data.pt']
+
+     def download(self) -> None:
+         # Step 01. Extract description
+         step1_folder = f"{self.raw_dir}/step_01_PubChemSTM_description"
+         if not os.path.exists(step1_folder):
+             os.makedirs(step1_folder)
+             valid_CID_set = set()
+             CID2name_raw, CID2name_extracted = defaultdict(list), defaultdict(
+                 list)
+             CID2text_raw, CID2text_extracted = defaultdict(list), defaultdict(
+                 list)
+
+             for page_index in tqdm(range(self.total_page_num)):
+                 page_num = page_index + 1
+                 f_out = open(
+                     f"{step1_folder}/Compound_description_{page_num}.txt", "w")
+
+                 description_data = requests.get(
+                     self.description_url.format(page_num)).json()
+
+                 description_data = description_data["Annotations"]
+                 assert description_data["Page"] == page_num
+
+                 record_list = description_data["Annotation"]
+
+                 for record in record_list:
+                     try:
+                         CID = record["LinkedRecords"]["CID"][0]
+                         if "Name" in record:
+                             name_raw = record["Name"]
+                             CID2name_raw[CID].append(name_raw)
+                         else:
+                             name_raw = None
+
+                         data_list = record["Data"]
+                         for data in data_list:
+                             description = data["Value"]["StringWithMarkup"][0][
+                                 "String"].strip()
+
+                             extracted_name, extracted_description, _ = extract_name(  # noqa: E501
+                                 name_raw, description)
+                             if extracted_name is not None:
+                                 CID2name_extracted[CID].append(extracted_name)
+
+                             CID2text_raw[CID].append(description)
+                             CID2text_extracted[CID].append(
+                                 extracted_description)
+
+                             valid_CID_set.add(CID)
+                             f_out.write(f"{CID}\n")
+                             f_out.write(f"{extracted_description}\n\n")
+                     except Exception:
+                         continue
+
+             valid_CID_list = sorted(list(valid_CID_set))
+             print(f"Total CID (with raw name) {len(CID2name_raw)}")
+             print(f"Total CID (with extracted name) {len(CID2name_extracted)}")
+             print(f"Total CID {len(valid_CID_list)}")
+
+             with open(f"{self.raw_dir}/CID2name_raw.json", "w") as f:
+                 json.dump(CID2name_raw, f)
+
+             with open(f"{self.raw_dir}/CID2name.json", "w") as f:
+                 json.dump(CID2name_extracted, f)
+
+             with open(f"{self.raw_dir}/CID2text_raw.json", "w") as f:
+                 json.dump(CID2text_raw, f)
+
+             with open(f"{self.raw_dir}/CID2text.json", "w") as f:
+                 json.dump(CID2text_extracted, f)
+
+         # Step 02. Download SDF Files
+         step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF"
+         if not os.path.exists(step2_folder):
+             for block_id in tqdm(range(self.total_block_num)):
+                 block_size = 500000
+                 l_id = block_id * block_size + 1
+                 r_id = (block_id + 1) * block_size
+
+                 compound_file_name = f"Compound_{l_id:09d}_{r_id:09d}.sdf.gz"
+                 download_url(f"{self.compound_url}/{compound_file_name}",
+                              step2_folder)
+
+     def process(self, use_mp: bool = False) -> None:
+         try:
+             from rdkit import Chem
+             from rdkit.Chem.rdchem import BondType as BT
+             WITH_RDKIT = True
+
+         except ImportError:
+             WITH_RDKIT = False
+
+         if not WITH_RDKIT:
+             print(("Using a pre-processed version of the dataset. Please "
+                    "install 'rdkit' to alternatively process the raw data."),
+                   file=sys.stderr)
+
+             data_list = fs.torch_load(self.raw_paths[0])
+             data_list = [Data(**data_dict) for data_dict in data_list]
+
+             if self.pre_filter is not None:
+                 data_list = [d for d in data_list if self.pre_filter(d)]
+
+             if self.pre_transform is not None:
+                 data_list = [self.pre_transform(d) for d in data_list]
+
+             self.save(data_list, self.processed_paths[0])
+             return
+
+         # Step 03. Filter out SDF
+         step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF"
+         step3_folder = f"{self.raw_dir}/step_03_PubChemSTM_filtered"
+         if not os.path.exists(step3_folder):
+             os.makedirs(step3_folder)
+             with open(f"{self.raw_dir}/CID2text.json") as f:
+                 CID2text = json.load(f)
+             target_CID_list = set(CID2text.keys())
+
+             block_size = 500000
+
+             def extract_one_SDF_file(block_id: int) -> None:
+                 valid_mol_count = 0
+
+                 writer = Chem.SDWriter(
+                     f'{step3_folder}/filtered_{block_id}.sdf')
+                 l_id = block_id * block_size + 1
+                 r_id = (block_id + 1) * block_size
+
+                 compound_file_name = f"Compound_{l_id:09d}_{r_id:09d}.sdf.gz"
+                 gzip_loader = gzip.open(f"{step2_folder}/{compound_file_name}")
+                 suppl = Chem.ForwardSDMolSupplier(gzip_loader)
+
+                 for mol in tqdm(suppl):
+                     if mol is None:
+                         continue
+                     cid = mol.GetProp("PUBCHEM_COMPOUND_CID")
+
+                     if cid not in target_CID_list:
+                         continue
+
+                     writer.write(mol)
+                     valid_mol_count += 1
+
+                 print(f"block id: {block_id}\nfound {valid_mol_count}\n\n")
+                 sys.stdout.flush()
+                 return
+
+             if use_mp:
+                 num_process = multiprocessing.cpu_count()
+                 print(f"{num_process} CPUs")
+                 num_process = 8
+                 p = Pool(num_process)
+
+                 block_id_list = np.arange(self.total_block_num)
+                 with p:
+                     p.map(extract_one_SDF_file, block_id_list)
+             else:
+                 for block_id in range(self.total_block_num):
+                     extract_one_SDF_file(block_id)
+
+         # Step 04. Merge SDF
+         with open(f"{self.raw_dir}/CID2text.json") as f:
+             CID2text = json.load(f)
+         target_CID_list = set(CID2text.keys())
+         print(f'The length of target_CID_list: {len(target_CID_list)}')
+
+         writer = Chem.SDWriter(f'{self.raw_dir}/molecules.sdf')
+
+         found_CID_set = set()
+         for block_id in range(self.total_block_num + 1):
+             compound_file_path = f"{step3_folder}/filtered_{block_id}.sdf"
+             try:
+                 suppl = Chem.SDMolSupplier(compound_file_path)
+
+                 for mol in tqdm(suppl):
+                     writer.write(mol)
+                     cid = mol.GetProp("PUBCHEM_COMPOUND_CID")
+                     found_CID_set.add(cid)
+             except Exception:
+                 print(f"block id: {block_id} with 0 valid SDF file")
+                 continue
+
+         print(f"In total: {len(found_CID_set)} molecules")
+
+         # Step 05. Convert to PyG data format
+         types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5}
+         bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}
+
+         data_list = []
+         # Real data
+         CID2text_file = f'{self.raw_dir}/CID2text.json'
+
+         with open(CID2text_file) as f:
+             CID2text_data = json.load(f)
+
+         suppl = Chem.SDMolSupplier(f'{self.raw_dir}/molecules.sdf')
+
+         llm = LLM(
+             # model_name='lmsys/vicuna-7b-v1.5',
+             model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
+             num_params=1,
+             dtype=torch.bfloat16,
+         )
+         prompt = ("Propose a question regarding the molecule '∼' "
+                   "whose answer is: {}:")
+         for mol in tqdm(suppl):
+             if mol.HasProp('PUBCHEM_COMPOUND_CID'):
+                 CID = mol.GetProp("PUBCHEM_COMPOUND_CID")
+                 CAN_SMILES = mol.GetProp("PUBCHEM_OPENEYE_CAN_SMILES")
+
+                 m: Chem.Mol = Chem.MolFromSmiles(CAN_SMILES)
+                 if m is None:
+                     continue
+                 RDKit_CAN_SMILES = Chem.MolToSmiles(m)
+
+                 ground_truth = CID2text_data[CID][0]
+
+                 instruction = llm.inference([prompt.format(ground_truth)])[0]
+
+                 x: torch.Tensor = torch.tensor([
+                     types[atom.GetSymbol()] if atom.GetSymbol() in types else 5
+                     for atom in m.GetAtoms()  # type: ignore
+                 ])
+                 x = one_hot(x, num_classes=len(types), dtype=torch.float)
+
+                 rows, cols, edge_types = [], [], []
+                 for bond in m.GetBonds():  # type: ignore
+                     i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
+                     edge_types += [bonds[bond.GetBondType()]] * 2
+                     rows += [i, j]
+                     cols += [j, i]
+
+                 edge_index = torch.tensor([rows, cols], dtype=torch.long)
+                 edge_type = torch.tensor(edge_types, dtype=torch.long)
+                 edge_attr = one_hot(edge_type, num_classes=len(bonds))
+
+                 data = Data(
+                     x=x,
+                     edge_index=edge_index,
+                     edge_attr=edge_attr,
+                     smiles=RDKit_CAN_SMILES,
+                     instruction=instruction,
+                     y=ground_truth,
+                 )
+
+                 if self.pre_filter is not None and not self.pre_filter(data):
+                     continue
+                 if self.pre_transform is not None:
+                     data = self.pre_transform(data)
+
+                 data_list.append(data)
+
+         self.save(data_list, self.processed_paths[0])
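
For orientation, a minimal usage sketch of the newly added MoleculeGPTDataset, based only on the constructor arguments visible in the diff above. The root path and the reduced page/block counts are illustrative assumptions; building the dataset from scratch also requires network access to PubChem, the optional rdkit dependency, and the torch_geometric.nn.nlp.LLM helper used in process().

from torch_geometric.datasets import MoleculeGPTDataset

# Placeholder values: fetch one page of PubChem descriptions and one SDF
# block (a block spans 500,000 compound IDs, per block_size in the code).
dataset = MoleculeGPTDataset(
    root='data/molecule_gpt',
    total_page_num=1,
    total_block_num=1,
)

data = dataset[0]
# Each sample is a Data object with x, edge_index and edge_attr, plus the
# string attributes smiles, instruction and y (the extracted description).
print(data.smiles, data.y)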