pyg-nightly 2.7.0.dev20250702__py3-none-any.whl → 2.7.0.dev20250703__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pyg-nightly might be problematic.
- {pyg_nightly-2.7.0.dev20250702.dist-info → pyg_nightly-2.7.0.dev20250703.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250702.dist-info → pyg_nightly-2.7.0.dev20250703.dist-info}/RECORD +16 -13
- torch_geometric/__init__.py +1 -1
- torch_geometric/datasets/__init__.py +4 -0
- torch_geometric/datasets/git_mol_dataset.py +1 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/nn/conv/meshcnn_conv.py +9 -15
- torch_geometric/nn/models/__init__.py +2 -0
- torch_geometric/nn/models/glem.py +7 -3
- torch_geometric/nn/models/protein_mpnn.py +304 -0
- torch_geometric/utils/convert.py +15 -8
- torch_geometric/utils/smiles.py +1 -1
- {pyg_nightly-2.7.0.dev20250702.dist-info → pyg_nightly-2.7.0.dev20250703.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250702.dist-info → pyg_nightly-2.7.0.dev20250703.dist-info}/licenses/LICENSE +0 -0
{pyg_nightly-2.7.0.dev20250702.dist-info → pyg_nightly-2.7.0.dev20250703.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pyg-nightly
-Version: 2.7.0.dev20250702
+Version: 2.7.0.dev20250703
 Summary: Graph Neural Network Library for PyTorch
 Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
 Author-email: Matthias Fey <matthias@pyg.org>
{pyg_nightly-2.7.0.dev20250702.dist-info → pyg_nightly-2.7.0.dev20250703.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
-torch_geometric/__init__.py,sha256=
+torch_geometric/__init__.py,sha256=8_AcpgPpDfVr7gwK3sYPD-pu9JYUmziOQzs9SPRvqcE,2250
 torch_geometric/_compile.py,sha256=9yqMTBKatZPr40WavJz9FjNi7pQj8YZAZOyZmmRGXgc,1351
 torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
 torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -55,7 +55,7 @@ torch_geometric/data/temporal.py,sha256=WOJ6gFrTLikaLhUvotyUF5ql14FkE5Ox3hNkdSp6
 torch_geometric/data/view.py,sha256=XjkVSc-UWZFCT4DlXLShZtO8duhFQkS9gq88zZXANsk,1089
 torch_geometric/data/lightning/__init__.py,sha256=w3En1tJfy3kSqe1MycpOyZpHFO3fxBCgNCUOznPA3YU,178
 torch_geometric/data/lightning/datamodule.py,sha256=IjucsIKRcNv16DIqILQnqa_sz72a4-yivoySmEllv2o,29353
-torch_geometric/datasets/__init__.py,sha256=
+torch_geometric/datasets/__init__.py,sha256=rgfUmjd9U3o8renKVl81Brscx4LOtwWmt6qAoaG41C4,6417
 torch_geometric/datasets/actor.py,sha256=oUxgJIX8bi5hJr1etWNYIFyVQNDDXi1nyVpHGGMEAGQ,4304
 torch_geometric/datasets/airfrans.py,sha256=8cCBmHPttrlKY_iwfyr-K-CUX_JEDjrIOg3r9dQSN7o,5439
 torch_geometric/datasets/airports.py,sha256=b3gkv3gY2JkUpmGiz36Z-g7EcnSfU8lBG1YsCOWdJ6k,3758
@@ -95,7 +95,7 @@ torch_geometric/datasets/gdelt_lite.py,sha256=zE1WagpgmsQARQhEgdCBtALRKyuQvIZqxT
 torch_geometric/datasets/ged_dataset.py,sha256=dtd-C6pCygNHLXgVfg3ZTWtTVHKT13Q3GlGrze1_rpo,9551
 torch_geometric/datasets/gemsec.py,sha256=oMTSryTgyed9z_4ydg3ql12KM-_35uqL1AoNls5nG8M,2820
 torch_geometric/datasets/geometry.py,sha256=-BxUMirZcUOf01c3avvF0b6wGPn-4S3Zj3Oau1RaJVk,4223
-torch_geometric/datasets/git_mol_dataset.py,sha256=
+torch_geometric/datasets/git_mol_dataset.py,sha256=l5u4U86tfjJdHtQPN7SM3Yjv25LD1Idtm7VHaqJqNik,10665
 torch_geometric/datasets/github.py,sha256=Qhqhkvi6eZ8VF_HqP1rL2iYToZavFNsQh7J1WdeM9dA,2687
 torch_geometric/datasets/gnn_benchmark_dataset.py,sha256=4P8n7czF-gf1egLYlAcSSvfB0GXIKpAbH5UjsuFld1M,6976
 torch_geometric/datasets/heterophilous_graph_dataset.py,sha256=yHHtwl4uPrid0vPOxvPV3sIS8HWdswar8FJ0h0OQ9is,4224
@@ -139,8 +139,9 @@ torch_geometric/datasets/pcqm4m.py,sha256=7ID_xXXIAyuNzYLI2lBWygZl9wGos-dbaz1b6E
 torch_geometric/datasets/planetoid.py,sha256=RksfwR_PI7qGVphs-T-4jXDepYwQCweMXElLm096hgg,7201
 torch_geometric/datasets/polblogs.py,sha256=IYzsvd4R0OojmOOZUoOdCwQYfwcTfth1PNtcBK1yOGc,3045
 torch_geometric/datasets/ppi.py,sha256=zPtg-omC7WYvr9Tzwkb7zNjpXLODsvxKxKdGEUswp2E,5030
+torch_geometric/datasets/protein_mpnn_dataset.py,sha256=TTeTVJMo0Rlt2_h9bbZMKJe3rTJcjCgY5cXGyWteBfA,17756
 torch_geometric/datasets/qm7.py,sha256=bYyK8xlh9kTr5vqueNbLu9EAjIXkQH1KX1VWnjKfOJc,3323
-torch_geometric/datasets/qm9.py,sha256=
+torch_geometric/datasets/qm9.py,sha256=Ub1t8KNeWFZvw50_Qk-80yNFeYFDwdAeyQtp3JHZs7o,17197
 torch_geometric/datasets/rcdd.py,sha256=gvOoM1tw_X5QMyBB4FkMUwNErMXAvImyjz5twktBAh8,5317
 torch_geometric/datasets/reddit.py,sha256=QUgiKTaj6YTOYbgWgqV8mPYsctOui2ujaM8f8qy81v0,3131
 torch_geometric/datasets/reddit2.py,sha256=WSdrhbDPcUEG37XWNUd0uKnqgI911MOcfjXmgjbTPoQ,4291
@@ -153,6 +154,7 @@ torch_geometric/datasets/snap_dataset.py,sha256=deJvB6cpIQ3bu_pcWoqgEo1-Kl_NcFi7
 torch_geometric/datasets/suite_sparse.py,sha256=eqjH4vAUq872qdk3YdLkZSwlu6r7HHpTgK0vEVGmY1s,3278
 torch_geometric/datasets/tag_dataset.py,sha256=qTnwr2N1tbWYeLGbItfv70UxQ3n1rKesjeVU3kcOCP8,14757
 torch_geometric/datasets/taobao.py,sha256=CUcZpbWsNTasevflO8zqP0YvENy89P7wpKS4MHaDJ6Q,4170
+torch_geometric/datasets/teeth3ds.py,sha256=hZvhcq9lsQENNFr5hk50w2T3CgxE_tlnQfrCgN6uIDQ,9919
 torch_geometric/datasets/tosca.py,sha256=nUSF8NQT1GlkwWQLshjWmr8xORsvRHzzIqhUyDCvABc,4632
 torch_geometric/datasets/tu_dataset.py,sha256=14OSaXBgVwT1dX2h1wZ3xVIwoo0GQBEfR3yWh6Q0VF0,7847
 torch_geometric/datasets/twitch.py,sha256=qfEerf-Uaojx2ZvegENowdG4E7RoUT_HUO9xtULadvo,3658
@@ -374,7 +376,7 @@ torch_geometric/nn/conv/hgt_conv.py,sha256=lUhTWUMovMtn9yR_b2-kLNLqHChGOUl2OtXBY
 torch_geometric/nn/conv/hypergraph_conv.py,sha256=4BosbbqJyprlI6QjPqIfMxCqnARU_0mUn1zcAQhbw90,8691
 torch_geometric/nn/conv/le_conv.py,sha256=DonmmYZOKk5wIlTZzzIfNKqBY6MO0MRxYhyr0YtNz-Q,3494
 torch_geometric/nn/conv/lg_conv.py,sha256=8jMa79iPsOUbXEfBIc3wmbvAD8T3d1j37LeIFTX3Yag,2369
-torch_geometric/nn/conv/meshcnn_conv.py,sha256=
+torch_geometric/nn/conv/meshcnn_conv.py,sha256=92zUcgfS0Fwv-MpddF4Ia1a65y7ddPAkazYf7D6kvwg,21951
 torch_geometric/nn/conv/message_passing.py,sha256=ZuTvSvodGy1GyAW4mHtuoMUuxclam-7opidYNY5IHm8,44377
 torch_geometric/nn/conv/mf_conv.py,sha256=SkOGMN1tFT9dcqy8xYowsB2ozw6QfkoArgR1BksZZaU,4340
 torch_geometric/nn/conv/mixhop_conv.py,sha256=qVDPWeWcnO7_eHM0ZnpKtr8SISjb4jp0xjgpoDrwjlk,4555
@@ -429,7 +431,7 @@ torch_geometric/nn/kge/distmult.py,sha256=dGQ0bVzjreZgFN1lXE23_IIidsiOq7ehPrMb-N
 torch_geometric/nn/kge/loader.py,sha256=5Uc1j3OUMQnBYSHDqL7pLCty1siFLzoPkztigYO2zP8,771
 torch_geometric/nn/kge/rotate.py,sha256=XLuO1AbyTt5cJxr97ZzoyAyIEsHKesgW5TvDmnGJAao,3208
 torch_geometric/nn/kge/transe.py,sha256=jlejq5BLMm-sb1wWcLDp7pZqCdelWBgjDIC8ctbjSdU,3088
-torch_geometric/nn/models/__init__.py,sha256=
+torch_geometric/nn/models/__init__.py,sha256=fbHQauZw9Snvl2PuN5cjZoAW8SwUl6E-p2IOmwUKB3A,2395
 torch_geometric/nn/models/attentive_fp.py,sha256=1z3iTV2O5W9tqHFAdno8FeBFeXmuG-TDZk4lwwVh3Ac,6634
 torch_geometric/nn/models/attract_repel.py,sha256=h9OyogT0NY0xiT0DkpJHMxH6ZUmo8R-CmwZdKEwq8Ek,5277
 torch_geometric/nn/models/autoencoder.py,sha256=nGje-zty78Y3hxOJ9o0_6QziJjOvBlknk6z0_fDQwQU,10770
@@ -442,7 +444,7 @@ torch_geometric/nn/models/dimenet.py,sha256=O2rqEx5HWs_lMwRD8eq6WMkbqJaCLL5zgWUJ
 torch_geometric/nn/models/dimenet_utils.py,sha256=Eyn_EiJqwKvuYj6BtRpSxrzMG3v4Gk98X9MxZ7uvwm4,5069
 torch_geometric/nn/models/g_retriever.py,sha256=tVibbqM_r-1LnA3R3oVyzp0bpuN3qPoYqcU6LZ8dYEk,8260
 torch_geometric/nn/models/git_mol.py,sha256=Wc6Hx6RDDR7sDWRWHfA5eK9e9gFsrTZ9OLmpMfoj3pE,12676
-torch_geometric/nn/models/glem.py,sha256=
+torch_geometric/nn/models/glem.py,sha256=GlL_I63g-_5eTycSGRj720YntldQ-CQ351RaDPc6XAU,16674
 torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
 torch_geometric/nn/models/gpse.py,sha256=acEAeeicLgzKRL54WhvIFxjA5XViHgXgMEH-NgbMdqI,41971
 torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
@@ -459,6 +461,7 @@ torch_geometric/nn/models/molecule_gpt.py,sha256=k-XULH6jaurj-R2EE4sIWTkqlNqa3Cz
 torch_geometric/nn/models/neural_fingerprint.py,sha256=pTLJgU9Uh2Lnf9bggLj4cKI8YdEFcMF-9MALuubqbuQ,2378
 torch_geometric/nn/models/node2vec.py,sha256=81Ku4Rp4IwLEAy06KEgJ2fYtXXVL_uv_Hb8lBr6YXrE,7664
 torch_geometric/nn/models/pmlp.py,sha256=dcAASVSyQMMhItSfEJWPeAFh0R3tNCwAHwdrShwQ8o4,3538
+torch_geometric/nn/models/protein_mpnn.py,sha256=QXHfltiJPmakpzgJKw_1vwCGBlszv9nfY4r4F38Sg9k,11031
 torch_geometric/nn/models/re_net.py,sha256=pz66q5b5BoGDNVQvpEGS2RGoeKvpjkYAv9r3WAuvITk,8986
 torch_geometric/nn/models/rect.py,sha256=2F3XyyvHTAEuqfJpiNB5M8pSGy738LhPiom5I-SDWqM,2808
 torch_geometric/nn/models/rev_gnn.py,sha256=Bpme087Zs227lcB0ODOKWsxaly67q96wseaRt6bacjs,11796
@@ -613,7 +616,7 @@ torch_geometric/utils/_tree_decomposition.py,sha256=ZtpjPQJgXbQWtSWjo-Fmhrov0DGO
 torch_geometric/utils/_trim_to_layer.py,sha256=cauOEzMJJK4w9BC-Pg1bHVncBYqG9XxQex3rn10BFjc,8339
 torch_geometric/utils/_unbatch.py,sha256=B0vjKI96PtHvSBG8F_lqvsiJE134aVjUurPZsG6UZRI,2378
 torch_geometric/utils/augmentation.py,sha256=1F0YCuaklZ9ZbXxdFV0oOoemWvLd8p60WvFo2chzl7E,8600
-torch_geometric/utils/convert.py,sha256=
+torch_geometric/utils/convert.py,sha256=RE5n5no74Xu39-QMWFE0-1RvTgykdK33ymyjF9WcuSs,21938
 torch_geometric/utils/cross_entropy.py,sha256=ZFS5bivtzv3EV9zqgKsekmuQyoZZggPSclhl_tRNHxo,3047
 torch_geometric/utils/dropout.py,sha256=gg0rDnD4FLvBaKSoLAkZwViAQflhLefJm6_Mju5dmQs,11416
 torch_geometric/utils/embedding.py,sha256=Ac_MPSrZGpw-e-gU6Yz-seUioC2WZxBSSzXFeclGwMk,5232
@@ -634,13 +637,13 @@ torch_geometric/utils/num_nodes.py,sha256=F15ciTFOe8AxjkUh1wKH7RLmJvQYYpz-l3pPPv
 torch_geometric/utils/ppr.py,sha256=ebiHbQqRJsQbGUI5xu-IkzQSQsgIaC71vgO0KcXIKAk,4055
 torch_geometric/utils/random.py,sha256=Rv5HlhG5310rytbT9EZ7xWLGKQfozfz1azvYi5nx2-U,5148
 torch_geometric/utils/repeat.py,sha256=RxCoRoEisaP6NouXPPW5tY1Rn-tIfrmpJPm0qGP6W8M,815
-torch_geometric/utils/smiles.py,sha256=
+torch_geometric/utils/smiles.py,sha256=CFqeNtSBXQtY9Ex2gQzI0La490IpVVrm01QdRYEpV7w,7114
 torch_geometric/utils/sparse.py,sha256=1DbaEwdyvnzvg5qVjPlnWcEVDMkxrQLX1jJ0dr6P4js,25135
 torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5nUAUjw,6222
 torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
 torch_geometric/visualization/graph.py,sha256=mfZHXYfiU-CWMtfawYc80IxVwVmtK9hbIkSKhM_j7oI,14311
 torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
-pyg_nightly-2.7.0.dev20250702.dist-info/licenses/LICENSE,sha256=
-pyg_nightly-2.7.0.dev20250702.dist-info/WHEEL,sha256=
-pyg_nightly-2.7.0.dev20250702.dist-info/METADATA,sha256=
-pyg_nightly-2.7.0.dev20250702.dist-info/RECORD,,
+pyg_nightly-2.7.0.dev20250703.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
+pyg_nightly-2.7.0.dev20250703.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+pyg_nightly-2.7.0.dev20250703.dist-info/METADATA,sha256=lm-xVNswGkfSaczkhV_ANegM0oGwBEfhCv75gwj0X5Q,63005
+pyg_nightly-2.7.0.dev20250703.dist-info/RECORD,,
torch_geometric/__init__.py
CHANGED
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')

-__version__ = '2.7.0.dev20250702'
+__version__ = '2.7.0.dev20250703'

 __all__ = [
     'Index',
torch_geometric/datasets/__init__.py
CHANGED
@@ -81,8 +81,10 @@ from .web_qsp_dataset import WebQSPDataset, CWQDataset
 from .git_mol_dataset import GitMolDataset
 from .molecule_gpt_dataset import MoleculeGPTDataset
 from .instruct_mol_dataset import InstructMolDataset
+from .protein_mpnn_dataset import ProteinMPNNDataset
 from .tag_dataset import TAGDataset
 from .city import CityNetwork
+from .teeth3ds import Teeth3DS

 from .dbp15k import DBP15K
 from .aminer import AMiner
@@ -201,8 +203,10 @@ homo_datasets = [
     'GitMolDataset',
     'MoleculeGPTDataset',
     'InstructMolDataset',
+    'ProteinMPNNDataset',
     'TAGDataset',
     'CityNetwork',
+    'Teeth3DS',
 ]

 hetero_datasets = [
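With these re-exports in place, both new datasets resolve from the package root:

from torch_geometric.datasets import ProteinMPNNDataset, Teeth3DS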
torch_geometric/datasets/protein_mpnn_dataset.py
ADDED
@@ -0,0 +1,451 @@
+import os
+import pickle
+import random
+from collections import defaultdict
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from torch_geometric.data import (
+    Data,
+    InMemoryDataset,
+    download_url,
+    extract_tar,
+)
+
+
+class ProteinMPNNDataset(InMemoryDataset):
+    r"""The ProteinMPNN dataset from the `"Robust deep learning based protein
+    sequence design using ProteinMPNN"
+    <https://www.biorxiv.org/content/10.1101/2022.06.03.494563v1>`_ paper.
+
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        size (str): Size of the PDB information to train the model.
+            If :obj:`"small"`, loads the small dataset (229.4 MB).
+            If :obj:`"large"`, loads the large dataset (64.1 GB).
+            (default: :obj:`"small"`)
+        split (str, optional): If :obj:`"train"`, loads the training dataset.
+            If :obj:`"valid"`, loads the validation dataset.
+            If :obj:`"test"`, loads the test dataset.
+            (default: :obj:`"train"`)
+        datacut (str, optional): Date cutoff to filter the dataset.
+            (default: :obj:`"2030-01-01"`)
+        rescut (float, optional): PDB resolution cutoff.
+            (default: :obj:`3.5`)
+        homo (float, optional): Homology cutoff.
+            (default: :obj:`0.70`)
+        max_length (int, optional): Maximum length of the protein complex.
+            (default: :obj:`10000`)
+        num_units (int, optional): Number of units of the protein complex.
+            (default: :obj:`150`)
+        transform (callable, optional): A function/transform that takes in an
+            :obj:`torch_geometric.data.Data` object and returns a transformed
+            version. The data object will be transformed before every access.
+            (default: :obj:`None`)
+        pre_transform (callable, optional): A function/transform that takes in
+            an :obj:`torch_geometric.data.Data` object and returns a
+            transformed version. The data object will be transformed before
+            being saved to disk. (default: :obj:`None`)
+        pre_filter (callable, optional): A function that takes in an
+            :obj:`torch_geometric.data.Data` object and returns a boolean
+            value, indicating whether the data object should be included in
+            the final dataset. (default: :obj:`None`)
+        force_reload (bool, optional): Whether to re-process the dataset.
+            (default: :obj:`False`)
+    """
+
+    raw_url = {
+        'small':
+        'https://files.ipd.uw.edu/pub/training_sets/'
+        'pdb_2021aug02_sample.tar.gz',
+        'large':
+        'https://files.ipd.uw.edu/pub/training_sets/'
+        'pdb_2021aug02.tar.gz',
+    }
+
+    splits = {
+        'train': 1,
+        'valid': 2,
+        'test': 3,
+    }
+
+    def __init__(
+        self,
+        root: str,
+        size: str = 'small',
+        split: str = 'train',
+        datacut: str = '2030-01-01',
+        rescut: float = 3.5,
+        homo: float = 0.70,
+        max_length: int = 10000,
+        num_units: int = 150,
+        transform: Optional[Callable] = None,
+        pre_transform: Optional[Callable] = None,
+        pre_filter: Optional[Callable] = None,
+        force_reload: bool = False,
+    ) -> None:
+        self.size = size
+        self.split = split
+        self.datacut = datacut
+        self.rescut = rescut
+        self.homo = homo
+        self.max_length = max_length
+        self.num_units = num_units
+
+        self.sub_folder = self.raw_url[self.size].split('/')[-1].split('.')[0]
+
+        super().__init__(root, transform, pre_transform, pre_filter,
+                         force_reload=force_reload)
+        self.load(self.processed_paths[self.splits[self.split]])
+
+    @property
+    def raw_file_names(self) -> List[str]:
+        return [
+            f'{self.sub_folder}/{f}'
+            for f in ['list.csv', 'valid_clusters.txt', 'test_clusters.txt']
+        ]
+
+    @property
+    def processed_file_names(self) -> List[str]:
+        return ['splits.pkl', 'train.pt', 'valid.pt', 'test.pt']
+
+    def download(self) -> None:
+        file_path = download_url(self.raw_url[self.size], self.raw_dir)
+        extract_tar(file_path, self.raw_dir)
+        os.unlink(file_path)
+
+    def process(self) -> None:
+        alphabet_set = set(list('ACDEFGHIKLMNPQRSTVWYX'))
+        cluster_ids = self._process_split()
+        total_items = sum(len(items) for items in cluster_ids.values())
+        data_list = []
+
+        with tqdm(total=total_items, desc="Processing") as pbar:
+            for _, items in cluster_ids.items():
+                for chain_id, _ in items:
+                    item = self._process_pdb1(chain_id)
+
+                    if 'label' not in item:
+                        pbar.update(1)
+                        continue
+                    if len(list(np.unique(item['idx']))) >= 352:
+                        pbar.update(1)
+                        continue
+
+                    my_dict = self._process_pdb2(item)
+
+                    if len(my_dict['seq']) > self.max_length:
+                        pbar.update(1)
+                        continue
+                    bad_chars = set(list(
+                        my_dict['seq'])).difference(alphabet_set)
+                    if len(bad_chars) > 0:
+                        pbar.update(1)
+                        continue
+
+                    x_chain_all, chain_seq_label_all, mask, chain_mask_all, residue_idx, chain_encoding_all = self._process_pdb3(  # noqa: E501
+                        my_dict)
+
+                    data = Data(
+                        x=x_chain_all,  # [seq_len, 4, 3]
+                        chain_seq_label=chain_seq_label_all,  # [seq_len]
+                        mask=mask,  # [seq_len]
+                        chain_mask_all=chain_mask_all,  # [seq_len]
+                        residue_idx=residue_idx,  # [seq_len]
+                        chain_encoding_all=chain_encoding_all,  # [seq_len]
+                    )
+
+                    if self.pre_filter is not None and not self.pre_filter(
+                            data):
+                        continue
+                    if self.pre_transform is not None:
+                        data = self.pre_transform(data)
+
+                    data_list.append(data)
+
+                    if len(data_list) >= self.num_units:
+                        pbar.update(total_items - pbar.n)
+                        break
+                    pbar.update(1)
+                else:
+                    continue
+                break
+        self.save(data_list, self.processed_paths[self.splits[self.split]])
+
+    def _process_split(self) -> Dict[int, List[Tuple[str, int]]]:
+        import pandas as pd
+        save_path = self.processed_paths[0]
+
+        if os.path.exists(save_path):
+            print('Load split')
+            with open(save_path, 'rb') as f:
+                data = pickle.load(f)
+        else:
+            # CHAINID, DEPOSITION, RESOLUTION, HASH, CLUSTER, SEQUENCE
+            df = pd.read_csv(self.raw_paths[0])
+            df = df[(df['RESOLUTION'] <= self.rescut)
+                    & (df['DEPOSITION'] <= self.datacut)]
+
+            val_ids = pd.read_csv(self.raw_paths[1], header=None)[0].tolist()
+            test_ids = pd.read_csv(self.raw_paths[2], header=None)[0].tolist()
+
+            # compile training and validation sets
+            data = {
+                'train': defaultdict(list),
+                'valid': defaultdict(list),
+                'test': defaultdict(list),
+            }
+
+            for _, r in tqdm(df.iterrows(), desc='Processing split',
+                             total=len(df)):
+                cluster_id = r['CLUSTER']
+                hash_id = r['HASH']
+                chain_id = r['CHAINID']
+                if cluster_id in val_ids:
+                    data['valid'][cluster_id].append((chain_id, hash_id))
+                elif cluster_id in test_ids:
+                    data['test'][cluster_id].append((chain_id, hash_id))
+                else:
+                    data['train'][cluster_id].append((chain_id, hash_id))
+
+            with open(save_path, 'wb') as f:
+                pickle.dump(data, f)
+
+        return data[self.split]
+
+    def _process_pdb1(self, chain_id: str) -> Dict[str, Any]:
+        pdbid, chid = chain_id.split('_')
+        prefix = f'{self.raw_dir}/{self.sub_folder}/pdb/{pdbid[1:3]}/{pdbid}'
+        # load metadata
+        if not os.path.isfile(f'{prefix}.pt'):
+            return {'seq': np.zeros(5)}
+        meta = torch.load(f'{prefix}.pt')
+        asmb_ids = meta['asmb_ids']
+        asmb_chains = meta['asmb_chains']
+        chids = np.array(meta['chains'])
+
+        # find candidate assemblies which contain chid chain
+        asmb_candidates = {
+            a
+            for a, b in zip(asmb_ids, asmb_chains) if chid in b.split(',')
+        }
+
+        # if the chain is missing from all the assemblies
+        # then return this chain alone
+        if len(asmb_candidates) < 1:
+            chain = torch.load(f'{prefix}_{chid}.pt')
+            L = len(chain['seq'])
+            return {
+                'seq': chain['seq'],
+                'xyz': chain['xyz'],
+                'idx': torch.zeros(L).int(),
+                'masked': torch.Tensor([0]).int(),
+                'label': chain_id,
+            }
+
+        # randomly pick one assembly from candidates
+        asmb_i = random.sample(list(asmb_candidates), 1)
+
+        # indices of selected transforms
+        idx = np.where(np.array(asmb_ids) == asmb_i)[0]
+
+        # load relevant chains
+        chains = {
+            c: torch.load(f'{prefix}_{c}.pt')
+            for i in idx
+            for c in asmb_chains[i] if c in meta['chains']
+        }
+
+        # generate assembly
+        asmb = {}
+        for k in idx:
+
+            # pick k-th xform
+            xform = meta[f'asmb_xform{k}']
+            u = xform[:, :3, :3]
+            r = xform[:, :3, 3]
+
+            # select chains which k-th xform should be applied to
+            s1 = set(meta['chains'])
+            s2 = set(asmb_chains[k].split(','))
+            chains_k = s1 & s2
+
+            # transform selected chains
+            for c in chains_k:
+                try:
+                    xyz = chains[c]['xyz']
+                    xyz_ru = torch.einsum('bij,raj->brai', u, xyz) + r[:, None,
+                                                                       None, :]
+                    asmb.update({
+                        (c, k, i): xyz_i
+                        for i, xyz_i in enumerate(xyz_ru)
+                    })
+                except KeyError:
+                    return {'seq': np.zeros(5)}
+
+        # select chains which share considerable similarity to chid
+        seqid = meta['tm'][chids == chid][0, :, 1]
+        homo = {
+            ch_j
+            for seqid_j, ch_j in zip(seqid, chids) if seqid_j > self.homo
+        }
+        # stack all chains in the assembly together
+        seq: str = ''
+        xyz_all: List[torch.Tensor] = []
+        idx_all: List[torch.Tensor] = []
+        masked: List[int] = []
+        seq_list = []
+        for counter, (k, v) in enumerate(asmb.items()):
+            seq += chains[k[0]]['seq']
+            seq_list.append(chains[k[0]]['seq'])
+            xyz_all.append(v)
+            idx_all.append(torch.full((v.shape[0], ), counter))
+            if k[0] in homo:
+                masked.append(counter)
+
+        return {
+            'seq': seq,
+            'xyz': torch.cat(xyz_all, dim=0),
+            'idx': torch.cat(idx_all, dim=0),
+            'masked': torch.Tensor(masked).int(),
+            'label': chain_id,
+        }
+
+    def _process_pdb2(self, t: Dict[str, Any]) -> Dict[str, Any]:
+        init_alphabet = list(
+            'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz')
+        extra_alphabet = [str(item) for item in list(np.arange(300))]
+        chain_alphabet = init_alphabet + extra_alphabet
+        my_dict: Dict[str, Union[str, int, Dict[str, Any], List[Any]]] = {}
+        concat_seq = ''
+        mask_list = []
+        visible_list = []
+        for idx in list(np.unique(t['idx'])):
+            letter = chain_alphabet[idx]
+            res = np.argwhere(t['idx'] == idx)
+            initial_sequence = "".join(list(
+                np.array(list(t['seq']))[res][
+                    0,
+                ]))
+            if initial_sequence[-6:] == "HHHHHH":
+                res = res[:, :-6]
+            if initial_sequence[0:6] == "HHHHHH":
+                res = res[:, 6:]
+            if initial_sequence[-7:-1] == "HHHHHH":
+                res = res[:, :-7]
+            if initial_sequence[-8:-2] == "HHHHHH":
+                res = res[:, :-8]
+            if initial_sequence[-9:-3] == "HHHHHH":
+                res = res[:, :-9]
+            if initial_sequence[-10:-4] == "HHHHHH":
+                res = res[:, :-10]
+            if initial_sequence[1:7] == "HHHHHH":
+                res = res[:, 7:]
+            if initial_sequence[2:8] == "HHHHHH":
+                res = res[:, 8:]
+            if initial_sequence[3:9] == "HHHHHH":
+                res = res[:, 9:]
+            if initial_sequence[4:10] == "HHHHHH":
+                res = res[:, 10:]
+            if res.shape[1] >= 4:
+                chain_seq = "".join(list(np.array(list(t['seq']))[res][0]))
+                my_dict[f'seq_chain_{letter}'] = chain_seq
+                concat_seq += chain_seq
+                if idx in t['masked']:
+                    mask_list.append(letter)
+                else:
+                    visible_list.append(letter)
+                coords_dict_chain = {}
+                all_atoms = np.array(t['xyz'][res])[0]  # [L, 14, 3]
+                for i, c in enumerate(['N', 'CA', 'C', 'O']):
+                    coords_dict_chain[
+                        f'{c}_chain_{letter}'] = all_atoms[:, i, :].tolist()
+                my_dict[f'coords_chain_{letter}'] = coords_dict_chain
+        my_dict['name'] = t['label']
+        my_dict['masked_list'] = mask_list
+        my_dict['visible_list'] = visible_list
+        my_dict['num_of_chains'] = len(mask_list) + len(visible_list)
+        my_dict['seq'] = concat_seq
+        return my_dict
+
+    def _process_pdb3(
+        self, b: Dict[str, Any]
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
+               torch.Tensor, torch.Tensor]:
+        L = len(b['seq'])
+        # residue idx with jumps across chains
+        residue_idx = -100 * np.ones([L], dtype=np.int32)
+        # get the list of masked / visible chains
+        masked_chains, visible_chains = b['masked_list'], b['visible_list']
+        visible_temp_dict, masked_temp_dict = {}, {}
+        for letter in masked_chains + visible_chains:
+            chain_seq = b[f'seq_chain_{letter}']
+            if letter in visible_chains:
+                visible_temp_dict[letter] = chain_seq
+            elif letter in masked_chains:
+                masked_temp_dict[letter] = chain_seq
+        # check for duplicate chains (same sequence but different identity)
+        for _, vm in masked_temp_dict.items():
+            for kv, vv in visible_temp_dict.items():
+                if vm == vv:
+                    if kv not in masked_chains:
+                        masked_chains.append(kv)
+                    if kv in visible_chains:
+                        visible_chains.remove(kv)
+        # build protein data structures
+        all_chains = masked_chains + visible_chains
+        np.random.shuffle(all_chains)
+        x_chain_list = []
+        chain_mask_list = []
+        chain_seq_list = []
+        chain_encoding_list = []
+        c, l0, l1 = 1, 0, 0
+        for letter in all_chains:
+            chain_seq = b[f'seq_chain_{letter}']
+            chain_length = len(chain_seq)
+            chain_coords = b[f'coords_chain_{letter}']
+            x_chain = np.stack([
+                chain_coords[c] for c in [
+                    f'N_chain_{letter}', f'CA_chain_{letter}',
+                    f'C_chain_{letter}', f'O_chain_{letter}'
+                ]
+            ], 1)  # [chain_length, 4, 3]
+            x_chain_list.append(x_chain)
+            chain_seq_list.append(chain_seq)
+            if letter in visible_chains:
+                chain_mask = np.zeros(chain_length)  # 0 for visible chains
+            elif letter in masked_chains:
+                chain_mask = np.ones(chain_length)  # 1 for masked chains
+            chain_mask_list.append(chain_mask)
+            chain_encoding_list.append(c * np.ones(chain_length))
+            l1 += chain_length
+            residue_idx[l0:l1] = 100 * (c - 1) + np.arange(l0, l1)
+            l0 += chain_length
+            c += 1
+        x_chain_all = np.concatenate(x_chain_list, 0)  # [L, 4, 3]
+        chain_seq_all = "".join(chain_seq_list)
+        # [L,] 1.0 for places that need to be predicted
+        chain_mask_all = np.concatenate(chain_mask_list, 0)
+        chain_encoding_all = np.concatenate(chain_encoding_list, 0)
+
+        # Convert to labels
+        alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
+        chain_seq_label_all = np.asarray(
+            [alphabet.index(a) for a in chain_seq_all], dtype=np.int32)
+
+        isnan = np.isnan(x_chain_all)
+        mask = np.isfinite(np.sum(x_chain_all, (1, 2))).astype(np.float32)
+        x_chain_all[isnan] = 0.
+
+        # Conversion
+        return (
+            torch.from_numpy(x_chain_all).to(dtype=torch.float32),
+            torch.from_numpy(chain_seq_label_all).to(dtype=torch.long),
+            torch.from_numpy(mask).to(dtype=torch.float32),
+            torch.from_numpy(chain_mask_all).to(dtype=torch.float32),
+            torch.from_numpy(residue_idx).to(dtype=torch.long),
+            torch.from_numpy(chain_encoding_all).to(dtype=torch.long),
+        )
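For orientation, a minimal usage sketch of the new dataset; the root path and loader settings are illustrative assumptions, not part of this release:

from torch_geometric.datasets import ProteinMPNNDataset
from torch_geometric.loader import DataLoader

# 'small' downloads the 229.4 MB sample archive; processing keeps at most
# `num_units` complexes per split.
dataset = ProteinMPNNDataset(root='data/ProteinMPNN', size='small',
                             split='train')
loader = DataLoader(dataset, batch_size=8, shuffle=True)
batch = next(iter(loader))  # x: [N, 4, 3] backbone coordinates per residue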
torch_geometric/datasets/qm9.py
CHANGED
@@ -202,7 +202,7 @@ class QM9(InMemoryDataset):
             from rdkit import Chem, RDLogger
             from rdkit.Chem.rdchem import BondType as BT
             from rdkit.Chem.rdchem import HybridizationType
-            RDLogger.DisableLog('rdApp.*')
+            RDLogger.DisableLog('rdApp.*')
             WITH_RDKIT = True

         except ImportError:
torch_geometric/datasets/teeth3ds.py
ADDED
@@ -0,0 +1,269 @@
+import json
+import os
+import os.path as osp
+from glob import glob
+from typing import Callable, Dict, List, Optional
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from torch_geometric.data import (
+    Data,
+    InMemoryDataset,
+    download_url,
+    extract_zip,
+)
+
+
+class Teeth3DS(InMemoryDataset):
+    r"""The Teeth3DS+ dataset from the `"An Extended Benchmark for Intra-oral
+    3D Scans Analysis" <https://crns-smartvision.github.io/teeth3ds/>`_ paper.
+
+    This dataset is the first comprehensive public benchmark designed to
+    advance the field of intra-oral 3D scan analysis, developed as part of
+    the 3DTeethSeg 2022 and 3DTeethLand 2024 MICCAI challenges, aiming to
+    drive research in teeth identification, segmentation, labeling,
+    3D modeling, and dental landmark identification.
+    The dataset includes at least 1,800 intra-oral scans (containing 23,999
+    annotated teeth) collected from 900 patients, covering both upper and
+    lower jaws separately.
+
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        split (str): The split name (one of :obj:`"Teeth3DS"`,
+            :obj:`"3DTeethSeg22_challenge"` or
+            :obj:`"3DTeethLand_challenge"`).
+        train (bool, optional): If :obj:`True`, loads the training dataset,
+            otherwise the test dataset. (default: :obj:`True`)
+        num_samples (int, optional): Number of points to sample from each
+            mesh. (default: :obj:`30000`)
+        transform (callable, optional): A function/transform that takes in an
+            :obj:`torch_geometric.data.Data` object and returns a transformed
+            version. The data object will be transformed before every access.
+            (default: :obj:`None`)
+        pre_transform (callable, optional): A function/transform that takes in
+            an :obj:`torch_geometric.data.Data` object and returns a
+            transformed version. The data object will be transformed before
+            being saved to disk. (default: :obj:`None`)
+        force_reload (bool, optional): Whether to re-process the dataset.
+            (default: :obj:`False`)
+    """
+    urls = {
+        'data_part_1.zip':
+        'https://osf.io/download/qhprs/',
+        'data_part_2.zip':
+        'https://osf.io/download/4pwnr/',
+        'data_part_3.zip':
+        'https://osf.io/download/frwdp/',
+        'data_part_4.zip':
+        'https://osf.io/download/2arn4/',
+        'data_part_5.zip':
+        'https://osf.io/download/xrz5f/',
+        'data_part_6.zip':
+        'https://osf.io/download/23hgq/',
+        'data_part_7.zip':
+        'https://osf.io/download/u83ad/',
+        'train_test_split':
+        'https://files.de-1.osf.io/v1/'
+        'resources/xctdy/providers/osfstorage/?zip='
+    }
+
+    sample_url = {
+        'teeth3ds_sample': 'https://osf.io/download/vr38s/',
+    }
+
+    landmarks_urls = {
+        '3DTeethLand_landmarks_train.zip': 'https://osf.io/download/k5hbj/',
+        '3DTeethLand_landmarks_test.zip': 'https://osf.io/download/sqw5e/',
+    }
+
+    def __init__(
+        self,
+        root: str,
+        split: str = 'Teeth3DS',  # [3DTeethSeg22_challenge, 3DTeethLand_challenge]
+        train: bool = True,
+        num_samples: int = 30000,
+        transform: Optional[Callable] = None,
+        pre_transform: Optional[Callable] = None,
+        force_reload: bool = False,
+    ) -> None:
+
+        self.mode = 'training' if train else 'testing'
+        self.split = split
+        self.num_samples = num_samples
+
+        super().__init__(root, transform, pre_transform,
+                         force_reload=force_reload)
+
+    @property
+    def processed_dir(self) -> str:
+        return os.path.join(self.root, f'processed_{self.split}_{self.mode}')
+
+    @property
+    def raw_file_names(self) -> List[str]:
+        return ['license.txt']
+
+    @property
+    def processed_file_names(self) -> List[str]:
+        # Directory containing train/test split files:
+        split_subdir = 'teeth3ds_sample' if self.split == 'sample' else ''
+        split_dir = osp.join(
+            self.raw_dir,
+            split_subdir,
+            f'{self.split}_train_test_split',
+        )
+
+        split_files = glob(osp.join(split_dir, f'{self.mode}*.txt'))
+
+        # Collect all file names from the split files:
+        combined_list = []
+        for file_path in split_files:
+            with open(file_path) as file:
+                combined_list.extend(file.read().splitlines())
+
+        # Generate the list of processed file paths:
+        return [f'{file_name}.pt' for file_name in combined_list]
+
+    def download(self) -> None:
+        if self.split == 'sample':
+            for key, url in self.sample_url.items():
+                path = download_url(url, self.root, filename=key)
+                extract_zip(path, self.raw_dir)
+                os.unlink(path)
+        else:
+            for key, url in self.urls.items():
+                path = download_url(url, self.root, filename=key)
+                extract_zip(path, self.raw_dir)
+                os.unlink(path)
+            for key, url in self.landmarks_urls.items():
+                path = download_url(url, self.root, filename=key)
+                extract_zip(path, self.raw_dir)  # Extract each downloaded part
+                os.unlink(path)
+
+    def process_file(self, file_path: str) -> Optional[Data]:
+        """Processes the input file path to load mesh data, annotations,
+        and prepare the input features for a graph-based deep learning model.
+        """
+        import trimesh
+        from fpsample import bucket_fps_kdline_sampling
+
+        mesh = trimesh.load_mesh(file_path)
+
+        if isinstance(mesh, list):
+            # Handle the case where a list of Geometry objects is returned
+            mesh = mesh[0]
+
+        vertices = mesh.vertices
+        vertex_normals = mesh.vertex_normals
+
+        # Perform sampling on mesh vertices:
+        if len(vertices) < self.num_samples:
+            sampled_indices = np.random.choice(
+                len(vertices),
+                self.num_samples,
+                replace=True,
+            )
+        else:
+            sampled_indices = bucket_fps_kdline_sampling(
+                vertices,
+                self.num_samples,
+                h=5,
+                start_idx=0,
+            )
+
+        if len(sampled_indices) != self.num_samples:
+            raise RuntimeError(f"Sampled points mismatch, expected "
+                               f"{self.num_samples} points, but got "
+                               f"{len(sampled_indices)} for '{file_path}'")
+
+        # Extract features and annotations for the sampled points:
+        pos = torch.tensor(vertices[sampled_indices], dtype=torch.float)
+        x = torch.tensor(vertex_normals[sampled_indices], dtype=torch.float)
+
+        # Load segmentation annotations:
+        seg_annotation_path = file_path.replace('.obj', '.json')
+        if osp.exists(seg_annotation_path):
+            with open(seg_annotation_path) as f:
+                seg_annotations = json.load(f)
+            y = torch.tensor(
+                np.asarray(seg_annotations['labels'])[sampled_indices],
+                dtype=torch.float)
+            instances = torch.tensor(
+                np.asarray(seg_annotations['instances'])[sampled_indices],
+                dtype=torch.float)
+        else:
+            y = torch.empty(0, 3)
+            instances = torch.empty(0, 3)
+
+        # Load landmarks annotations:
+        landmarks_annotation_path = file_path.replace('.obj', '__kpt.json')
+
+        # Parse keypoint annotations into structured tensors:
+        keypoints_dict: Dict[str, List] = {
+            key: []
+            for key in [
+                'Mesial', 'Distal', 'Cusp', 'InnerPoint', 'OuterPoint',
+                'FacialPoint'
+            ]
+        }
+        keypoint_tensors: Dict[str, torch.Tensor] = {
+            key: torch.empty(0, 3)
+            for key in [
+                'Mesial', 'Distal', 'Cusp', 'InnerPoint', 'OuterPoint',
+                'FacialPoint'
+            ]
+        }
+        if osp.exists(landmarks_annotation_path):
+            with open(landmarks_annotation_path) as f:
+                landmarks_annotations = json.load(f)
+
+            for keypoint in landmarks_annotations['objects']:
+                keypoints_dict[keypoint['class']].extend(keypoint['coord'])
+
+            keypoint_tensors = {
+                k: torch.tensor(np.asarray(v),
+                                dtype=torch.float).reshape(-1, 3)
+                for k, v in keypoints_dict.items()
+            }
+
+        data = Data(
+            pos=pos,
+            x=x,
+            y=y,
+            instances=instances,
+            jaw=file_path.split('.obj')[0].split('_')[1],
+            mesial=keypoint_tensors['Mesial'],
+            distal=keypoint_tensors['Distal'],
+            cusp=keypoint_tensors['Cusp'],
+            inner_point=keypoint_tensors['InnerPoint'],
+            outer_point=keypoint_tensors['OuterPoint'],
+            facial_point=keypoint_tensors['FacialPoint'],
+        )
+
+        if self.pre_transform is not None:
+            data = self.pre_transform(data)
+
+        return data
+
+    def process(self) -> None:
+        for file in tqdm(self.processed_file_names):
+            name = file.split('.')[0]
+            path = osp.join(self.raw_dir, '**', '*', name + '.obj')
+            paths = glob(path)
+            if len(paths) == 1:
+                data = self.process_file(paths[0])
+                torch.save(data, osp.join(self.processed_dir, file))
+
+    def len(self) -> int:
+        return len(self.processed_file_names)
+
+    def get(self, idx: int) -> Data:
+        return torch.load(
+            osp.join(self.processed_dir, self.processed_file_names[idx]),
+            weights_only=False,
+        )
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}({len(self)}, '
+                f'mode={self.mode}, split={self.split})')
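Likewise, a hedged usage sketch for Teeth3DS; the local root directory is an assumption, and the 'sample' split is used here only to keep the download small:

from torch_geometric.datasets import Teeth3DS

# Each item is a fixed-size point cloud: `pos` holds `num_samples` sampled
# mesh vertices, `x` their normals, `y`/`instances` the segmentation
# annotations, plus one tensor per landmark class (mesial, distal, cusp, ...).
dataset = Teeth3DS(root='data/Teeth3DS', split='sample', train=True)
data = dataset[0]
print(data.pos.shape, data.x.shape)  # torch.Size([30000, 3]) for both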
torch_geometric/nn/conv/meshcnn_conv.py
CHANGED
@@ -1,7 +1,7 @@
 # The below is to suppress the warning on torch.nn.conv.MeshCNNConv::update
 # pyright: reportIncompatibleMethodOverride=false
+import warnings
 from typing import Optional
-from warnings import warn

 import torch
 from torch.nn import Linear, Module, ModuleList
@@ -456,13 +456,10 @@ class MeshCNNConv(MessagePassing):
                 {type(network)}"
             if not hasattr(network, "in_channels") and \
                not hasattr(network, "in_features"):
-                warn(
-                    f"kernel[{i}] does not have attribute
-
-
-                    {self.in_channels}-dimensional tensor. \
-                    Still, assuming user configured \
-                    correctly. Continuing..", stacklevel=2)
+                warnings.warn(
+                    f"kernel[{i}] does not have attribute 'in_channels' nor "
+                    f"'out_features'. The network must take as input a "
+                    f"{self.in_channels}-dimensional tensor.", stacklevel=2)
             else:
                 input_dimension = getattr(network, "in_channels",
                                           network.in_features)
@@ -475,13 +472,10 @@ class MeshCNNConv(MessagePassing):

             if not hasattr(network, "out_channels") and \
                not hasattr(network, "out_features"):
-                warn(
-                    f"kernel[{i}] does not have attribute
-
-
-                    {self.in_channels}-dimensional tensor. \
-                    Still, assuming user configured \
-                    correctly. Continuing..", stacklevel=2)
+                warnings.warn(
+                    f"kernel[{i}] does not have attribute 'in_channels' nor "
+                    f"'out_features'. The network must take as input a "
+                    f"{self.in_channels}-dimensional tensor.", stacklevel=2)
             else:
                 output_dimension = getattr(network, "out_channels",
                                            network.out_features)
torch_geometric/nn/models/__init__.py
CHANGED
@@ -32,6 +32,7 @@ from .visnet import ViSNet
 from .g_retriever import GRetriever
 from .git_mol import GITMol
 from .molecule_gpt import MoleculeGPT
+from .protein_mpnn import ProteinMPNN
 from .glem import GLEM
 from .sgformer import SGFormer
 # Deprecated:
@@ -86,6 +87,7 @@ __all__ = classes = [
     'GRetriever',
     'GITMol',
     'MoleculeGPT',
+    'ProteinMPNN',
     'GLEM',
     'SGFormer',
     'ARLinkPredictor',
torch_geometric/nn/models/glem.py
CHANGED
@@ -8,6 +8,13 @@ from torch_geometric.loader import DataLoader, NeighborLoader
 from torch_geometric.nn.models import GraphSAGE, basic_gnn


+def deal_nan(x):
+    if isinstance(x, torch.Tensor):
+        x = x.clone()
+        x[torch.isnan(x)] = 0.0
+    return x
+
+
 class GLEM(torch.nn.Module):
     r"""This GNN+LM co-training model is based on GLEM from the `"Learning on
     Large-scale Text-attributed Graphs via Variational Inference"
@@ -379,9 +386,6 @@ class GLEM(torch.nn.Module):
         is_augmented: use EM or just train GNN and LM with gold data

        """
-        def deal_nan(x):
-            return 0 if torch.isnan(x) else x
-
        if is_augmented and (sum(~is_gold) > 0):
            mle_loss = deal_nan(loss_func(logits[is_gold], labels[is_gold]))
            # all other labels beside from ground truth(gold labels)
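The old inner helper evaluated `torch.isnan(x)` as a scalar truth value, which only works for 0-dim tensors; the module-level replacement zeroes NaN entries element-wise on a clone and passes non-tensors through unchanged. A quick illustrative check of the new behavior:

import torch
from torch_geometric.nn.models.glem import deal_nan

t = torch.tensor([1.0, float('nan'), 2.0])
print(deal_nan(t))    # tensor([1., 0., 2.]); input tensor is left untouched
print(deal_nan(0.5))  # 0.5; non-tensor inputs pass through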
torch_geometric/nn/models/protein_mpnn.py
ADDED
@@ -0,0 +1,304 @@
+from itertools import product
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+
+from torch_geometric.nn import knn_graph
+from torch_geometric.nn.conv import MessagePassing
+from torch_geometric.utils import to_dense_adj, to_dense_batch
+
+
+class PositionWiseFeedForward(torch.nn.Module):
+    def __init__(self, in_channels: int, hidden_channels: int) -> None:
+        super().__init__()
+        self.out = torch.nn.Sequential(
+            torch.nn.Linear(in_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, in_channels),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.out(x)
+
+
+class PositionalEncoding(torch.nn.Module):
+    def __init__(self, hidden_channels: int,
+                 max_relative_feature: int = 32) -> None:
+        super().__init__()
+        self.max_relative_feature = max_relative_feature
+        self.emb = torch.nn.Embedding(2 * max_relative_feature + 2,
+                                      hidden_channels)
+
+    def forward(self, offset, mask) -> torch.Tensor:
+        d = torch.clip(offset + self.max_relative_feature, 0,
+                       2 * self.max_relative_feature) * mask + (1 - mask) * (
+                           2 * self.max_relative_feature + 1)  # noqa: E501
+        return self.emb(d.long())
+
+
+class Encoder(MessagePassing):
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        dropout: float = 0.1,
+        scale: float = 30,
+    ) -> None:
+        super().__init__()
+        self.out_v = torch.nn.Sequential(
+            torch.nn.Linear(in_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, hidden_channels),
+        )
+        self.out_e = torch.nn.Sequential(
+            torch.nn.Linear(in_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, hidden_channels),
+        )
+        self.dropout1 = torch.nn.Dropout(dropout)
+        self.dropout2 = torch.nn.Dropout(dropout)
+        self.dropout3 = torch.nn.Dropout(dropout)
+        self.norm1 = torch.nn.LayerNorm(hidden_channels)
+        self.norm2 = torch.nn.LayerNorm(hidden_channels)
+        self.norm3 = torch.nn.LayerNorm(hidden_channels)
+        self.scale = scale
+        self.dense = PositionWiseFeedForward(hidden_channels,
+                                             hidden_channels * 4)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        edge_index: torch.Tensor,
+        edge_attr: torch.Tensor,
+    ) -> torch.Tensor:
+        # x: [N, d_v]
+        # edge_index: [2, E]
+        # edge_attr: [E, d_e]
+        # update node features
+        h_message = self.propagate(x=x, edge_index=edge_index,
+                                   edge_attr=edge_attr)
+        dh = h_message / self.scale
+        x = self.norm1(x + self.dropout1(dh))
+        dh = self.dense(x)
+        x = self.norm2(x + self.dropout2(dh))
+        # update edge features
+        row, col = edge_index
+        x_i, x_j = x[row], x[col]
+        h_e = torch.cat([x_i, x_j, edge_attr], dim=-1)
+        h_e = self.out_e(h_e)
+        edge_attr = self.norm3(edge_attr + self.dropout3(h_e))
+        return x, edge_attr
+
+    def message(self, x_i: torch.Tensor, x_j: torch.Tensor,
+                edge_attr: torch.Tensor) -> torch.Tensor:
+        h = torch.cat([x_i, x_j, edge_attr], dim=-1)  # [E, 2*d_v + d_e]
+        h = self.out_e(h)  # [E, d_e]
+        return h
+
+
+class Decoder(MessagePassing):
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        dropout: float = 0.1,
+        scale: float = 30,
+    ) -> None:
+        super().__init__()
+        self.out_v = torch.nn.Sequential(
+            torch.nn.Linear(in_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, hidden_channels),
+        )
+        self.dropout1 = torch.nn.Dropout(dropout)
+        self.dropout2 = torch.nn.Dropout(dropout)
+        self.norm1 = torch.nn.LayerNorm(hidden_channels)
+        self.norm2 = torch.nn.LayerNorm(hidden_channels)
+        self.scale = scale
+        self.dense = PositionWiseFeedForward(hidden_channels,
+                                             hidden_channels * 4)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        edge_index: torch.Tensor,
+        edge_attr: torch.Tensor,
+        x_label: torch.Tensor,
+        mask: torch.Tensor,
+    ) -> torch.Tensor:
+        # x: [N, d_v]
+        # edge_index: [2, E]
+        # edge_attr: [E, d_e]
+        h_message = self.propagate(x=x, x_label=x_label, edge_index=edge_index,
+                                   edge_attr=edge_attr, mask=mask)
+        dh = h_message / self.scale
+        x = self.norm1(x + self.dropout1(dh))
+        dh = self.dense(x)
+        x = self.norm2(x + self.dropout2(dh))
+        return x
+
+    def message(self, x_i: torch.Tensor, x_j: torch.Tensor,
+                x_label_j: torch.Tensor, edge_attr: torch.Tensor,
+                mask: torch.Tensor) -> torch.Tensor:
+        h_1 = torch.cat([x_j, edge_attr, x_label_j], dim=-1)
+        h_0 = torch.cat([x_j, edge_attr, torch.zeros_like(x_label_j)], dim=-1)
+        h = h_1 * mask + h_0 * (1 - mask)
+        h = torch.concat([x_i, h], dim=-1)
+        h = self.out_v(h)
+        return h
+
+
+class ProteinMPNN(torch.nn.Module):
+    def __init__(
+        self,
+        hidden_dim: int = 128,
+        num_encoder_layers: int = 3,
+        num_decoder_layers: int = 3,
+        num_neighbors: int = 30,
+        num_rbf: int = 16,
+        dropout: float = 0.1,
+        augment_eps: float = 0.2,
+        num_positional_embedding: int = 16,
+        vocab_size: int = 21,
+    ) -> None:
+        super().__init__()
+        self.augment_eps = augment_eps
+        self.hidden_dim = hidden_dim
+        self.num_neighbors = num_neighbors
+        self.num_rbf = num_rbf
+        self.embedding = PositionalEncoding(num_positional_embedding)
+        self.edge_mlp = torch.nn.Sequential(
+            torch.nn.Linear(num_positional_embedding + 400, hidden_dim),
+            torch.nn.LayerNorm(hidden_dim),
+            torch.nn.Linear(hidden_dim, hidden_dim),
+        )
+        self.label_embedding = torch.nn.Embedding(vocab_size, hidden_dim)
+        self.encoder_layers = torch.nn.ModuleList([
+            Encoder(hidden_dim * 3, hidden_dim, dropout)
+            for _ in range(num_encoder_layers)
+        ])
+
+        self.decoder_layers = torch.nn.ModuleList([
+            Decoder(hidden_dim * 4, hidden_dim, dropout)
+            for _ in range(num_decoder_layers)
+        ])
+        self.output = torch.nn.Linear(hidden_dim, vocab_size)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                torch.nn.init.xavier_uniform_(p)
+
+    def _featurize(
+        self,
+        x: torch.Tensor,
+        mask: torch.Tensor,
+        batch: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        N, Ca, C, O = (x[:, i, :] for i in range(4))  # noqa: E741
+        b = Ca - N
+        c = C - Ca
+        a = torch.cross(b, c, dim=-1)
+        Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca
+
+        valid_mask = mask.bool()
+        valid_Ca = Ca[valid_mask]
+        valid_batch = batch[valid_mask]
+
+        edge_index = knn_graph(valid_Ca, k=self.num_neighbors,
+                               batch=valid_batch, loop=True)
+
+        row, col = edge_index
+        original_indices = torch.arange(Ca.size(0),
+                                        device=x.device)[valid_mask]
+        edge_index_original = torch.stack(
+            [original_indices[row], original_indices[col]], dim=0)
+        row, col = edge_index_original
+
+        rbf_all = []
+        for A, B in list(product([N, Ca, C, O, Cb], repeat=2)):
+            distances = torch.sqrt(torch.sum((A[row] - B[col])**2, 1) + 1e-6)
+            rbf = self._rbf(distances)
+            rbf_all.append(rbf)
+
+        return edge_index_original, torch.cat(rbf_all, dim=-1)
+
+    def _rbf(self, D: torch.Tensor) -> torch.Tensor:
+        D_min, D_max, D_count = 2., 22., self.num_rbf
+        D_mu = torch.linspace(D_min, D_max, D_count, device=D.device)
+        D_mu = D_mu.view([1, -1])
+        D_sigma = (D_max - D_min) / D_count
+        D_expand = torch.unsqueeze(D, -1)
+        RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
+        return RBF
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        chain_seq_label: torch.Tensor,
+        mask: torch.Tensor,
+        chain_mask_all: torch.Tensor,
+        residue_idx: torch.Tensor,
+        chain_encoding_all: torch.Tensor,
+        batch: torch.Tensor,
+    ) -> torch.Tensor:
+        device = x.device
+        if self.training and self.augment_eps > 0:
+            x = x + self.augment_eps * torch.randn_like(x)
+
+        edge_index, edge_attr = self._featurize(x, mask, batch)
+
+        row, col = edge_index
+        offset = residue_idx[row] - residue_idx[col]
+        # find self vs non-self interaction
+        e_chains = ((chain_encoding_all[row] -
+                     chain_encoding_all[col]) == 0).long()
+        e_pos = self.embedding(offset, e_chains)
+        h_e = self.edge_mlp(torch.cat([edge_attr, e_pos], dim=-1))
+        h_v = torch.zeros(x.size(0), self.hidden_dim, device=x.device)
+
+        # encoder
+        for encoder in self.encoder_layers:
+            h_v, h_e = encoder(h_v, edge_index, h_e)
+
+        # mask
+        h_label = self.label_embedding(chain_seq_label)
+        batch_chain_mask_all, _ = to_dense_batch(chain_mask_all * mask,
+                                                 batch)  # [B, N]
+        # 0 - visible - encoder, 1 - masked - decoder
+        decoding_order = torch.argsort(
+            (batch_chain_mask_all + 1e-4) * (torch.abs(
+                torch.randn(batch_chain_mask_all.shape, device=device))))
+        mask_size = batch_chain_mask_all.size(1)
+        permutation_matrix_reverse = F.one_hot(decoding_order,
+                                               num_classes=mask_size).float()
+        order_mask_backward = torch.einsum(
+            'ij, biq, bjp->bqp',
+            1 - torch.triu(torch.ones(mask_size, mask_size, device=device)),
+            permutation_matrix_reverse,
+            permutation_matrix_reverse,
+        )
+        adj = to_dense_adj(edge_index, batch)
+        mask_attend = order_mask_backward[adj.bool()].unsqueeze(-1)
+
+        # decoder
+        for decoder in self.decoder_layers:
+            h_v = decoder(
+                h_v,
+                edge_index,
+                h_e,
+                h_label,
+                mask_attend,
+            )
+
+        logits = self.output(h_v)
+        return F.log_softmax(logits, dim=-1)
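Putting the two new pieces together, a hedged end-to-end sketch wiring the dataset fields into the model's forward signature; the path, batch size, and NLL objective are illustrative assumptions:

import torch.nn.functional as F

from torch_geometric.datasets import ProteinMPNNDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import ProteinMPNN

dataset = ProteinMPNNDataset(root='data/ProteinMPNN')
loader = DataLoader(dataset, batch_size=2, shuffle=True)
model = ProteinMPNN()

batch = next(iter(loader))
# Log-probabilities over the 21-letter alphabet, one row per residue:
log_probs = model(batch.x, batch.chain_seq_label, batch.mask,
                  batch.chain_mask_all, batch.residue_idx,
                  batch.chain_encoding_all, batch.batch)
loss = F.nll_loss(log_probs, batch.chain_seq_label)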
torch_geometric/utils/convert.py
CHANGED
@@ -452,15 +452,22 @@ def to_cugraph(
     g = cugraph.Graph(directed=directed)
     df = cudf.from_dlpack(to_dlpack(edge_index.t()))

+    df = cudf.DataFrame({
+        'source':
+        cudf.from_dlpack(to_dlpack(edge_index[0])),
+        'destination':
+        cudf.from_dlpack(to_dlpack(edge_index[1])),
+    })
+
     if edge_weight is not None:
         assert edge_weight.dim() == 1
-        df['
+        df['weight'] = cudf.from_dlpack(to_dlpack(edge_weight))

     g.from_cudf_edgelist(
         df,
-        source=
-        destination=
-        edge_attr='
+        source='source',
+        destination='destination',
+        edge_attr='weight' if edge_weight is not None else None,
         renumber=relabel_nodes,
     )
@@ -476,13 +483,13 @@ def from_cugraph(g: Any) -> Tuple[Tensor, Optional[Tensor]]:
     """
     df = g.view_edge_list()

-    src = from_dlpack(df[
-    dst = from_dlpack(df[
+    src = from_dlpack(df[g.source_columns].to_dlpack()).long()
+    dst = from_dlpack(df[g.destination_columns].to_dlpack()).long()
     edge_index = torch.stack([src, dst], dim=0)

     edge_weight = None
-    if
-        edge_weight = from_dlpack(df[
+    if g.weight_column is not None:
+        edge_weight = from_dlpack(df[g.weight_column].to_dlpack())

     return edge_index, edge_weight
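The rewrite names the edge-list columns explicitly ('source', 'destination', 'weight') and reads them back through the graph's own column attributes rather than hard-coded labels. A round-trip sketch, assuming a working cugraph/cudf CUDA environment (tensor values are illustrative):

import torch

from torch_geometric.utils import from_cugraph, to_cugraph

edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]], device='cuda')
edge_weight = torch.rand(3, device='cuda')

g = to_cugraph(edge_index, edge_weight, relabel_nodes=False)
out_index, out_weight = from_cugraph(g)  # matches the inputs up to ordering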
torch_geometric/utils/smiles.py
CHANGED
File without changes
{pyg_nightly-2.7.0.dev20250702.dist-info → pyg_nightly-2.7.0.dev20250703.dist-info}/licenses/LICENSE
RENAMED
File without changes