pyg-nightly 2.7.0.dev20250702__py3-none-any.whl → 2.7.0.dev20250704__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic. Click here for more details.
- {pyg_nightly-2.7.0.dev20250702.dist-info → pyg_nightly-2.7.0.dev20250704.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250702.dist-info → pyg_nightly-2.7.0.dev20250704.dist-info}/RECORD +20 -15
- torch_geometric/__init__.py +1 -1
- torch_geometric/datasets/__init__.py +4 -0
- torch_geometric/datasets/git_mol_dataset.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +1 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/nn/attention/__init__.py +2 -0
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/conv/meshcnn_conv.py +9 -15
- torch_geometric/nn/models/__init__.py +4 -0
- torch_geometric/nn/models/glem.py +7 -3
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/protein_mpnn.py +304 -0
- torch_geometric/utils/convert.py +15 -8
- torch_geometric/utils/smiles.py +1 -1
- {pyg_nightly-2.7.0.dev20250702.dist-info → pyg_nightly-2.7.0.dev20250704.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250702.dist-info → pyg_nightly-2.7.0.dev20250704.dist-info}/licenses/LICENSE +0 -0
{pyg_nightly-2.7.0.dev20250702.dist-info → pyg_nightly-2.7.0.dev20250704.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pyg-nightly
|
|
3
|
-
Version: 2.7.0.
|
|
3
|
+
Version: 2.7.0.dev20250704
|
|
4
4
|
Summary: Graph Neural Network Library for PyTorch
|
|
5
5
|
Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
|
|
6
6
|
Author-email: Matthias Fey <matthias@pyg.org>
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
torch_geometric/__init__.py,sha256=
|
|
1
|
+
torch_geometric/__init__.py,sha256=GOuL0XBOcsFqK-Q-c_STDpzZAG-vsctiDiU_Tg9W3t8,2250
|
|
2
2
|
torch_geometric/_compile.py,sha256=9yqMTBKatZPr40WavJz9FjNi7pQj8YZAZOyZmmRGXgc,1351
|
|
3
3
|
torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
|
|
4
4
|
torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
|
|
@@ -55,7 +55,7 @@ torch_geometric/data/temporal.py,sha256=WOJ6gFrTLikaLhUvotyUF5ql14FkE5Ox3hNkdSp6
|
|
|
55
55
|
torch_geometric/data/view.py,sha256=XjkVSc-UWZFCT4DlXLShZtO8duhFQkS9gq88zZXANsk,1089
|
|
56
56
|
torch_geometric/data/lightning/__init__.py,sha256=w3En1tJfy3kSqe1MycpOyZpHFO3fxBCgNCUOznPA3YU,178
|
|
57
57
|
torch_geometric/data/lightning/datamodule.py,sha256=IjucsIKRcNv16DIqILQnqa_sz72a4-yivoySmEllv2o,29353
|
|
58
|
-
torch_geometric/datasets/__init__.py,sha256=
|
|
58
|
+
torch_geometric/datasets/__init__.py,sha256=rgfUmjd9U3o8renKVl81Brscx4LOtwWmt6qAoaG41C4,6417
|
|
59
59
|
torch_geometric/datasets/actor.py,sha256=oUxgJIX8bi5hJr1etWNYIFyVQNDDXi1nyVpHGGMEAGQ,4304
|
|
60
60
|
torch_geometric/datasets/airfrans.py,sha256=8cCBmHPttrlKY_iwfyr-K-CUX_JEDjrIOg3r9dQSN7o,5439
|
|
61
61
|
torch_geometric/datasets/airports.py,sha256=b3gkv3gY2JkUpmGiz36Z-g7EcnSfU8lBG1YsCOWdJ6k,3758
|
|
@@ -95,7 +95,7 @@ torch_geometric/datasets/gdelt_lite.py,sha256=zE1WagpgmsQARQhEgdCBtALRKyuQvIZqxT
|
|
|
95
95
|
torch_geometric/datasets/ged_dataset.py,sha256=dtd-C6pCygNHLXgVfg3ZTWtTVHKT13Q3GlGrze1_rpo,9551
|
|
96
96
|
torch_geometric/datasets/gemsec.py,sha256=oMTSryTgyed9z_4ydg3ql12KM-_35uqL1AoNls5nG8M,2820
|
|
97
97
|
torch_geometric/datasets/geometry.py,sha256=-BxUMirZcUOf01c3avvF0b6wGPn-4S3Zj3Oau1RaJVk,4223
|
|
98
|
-
torch_geometric/datasets/git_mol_dataset.py,sha256=
|
|
98
|
+
torch_geometric/datasets/git_mol_dataset.py,sha256=l5u4U86tfjJdHtQPN7SM3Yjv25LD1Idtm7VHaqJqNik,10665
|
|
99
99
|
torch_geometric/datasets/github.py,sha256=Qhqhkvi6eZ8VF_HqP1rL2iYToZavFNsQh7J1WdeM9dA,2687
|
|
100
100
|
torch_geometric/datasets/gnn_benchmark_dataset.py,sha256=4P8n7czF-gf1egLYlAcSSvfB0GXIKpAbH5UjsuFld1M,6976
|
|
101
101
|
torch_geometric/datasets/heterophilous_graph_dataset.py,sha256=yHHtwl4uPrid0vPOxvPV3sIS8HWdswar8FJ0h0OQ9is,4224
|
|
@@ -119,7 +119,7 @@ torch_geometric/datasets/medshapenet.py,sha256=eCBCXKpueweCwDSf_Q4_MwVA3IbJd04FS
|
|
|
119
119
|
torch_geometric/datasets/mixhop_synthetic_dataset.py,sha256=4NNvTHUvvV6pcqQCyVDS5XhppXUeF2H9GTfFoc49eyU,3951
|
|
120
120
|
torch_geometric/datasets/mnist_superpixels.py,sha256=o2ArbZ0_OE0u8VCaHmWwvngESlOFr9oM9dSEP_tjAS4,3340
|
|
121
121
|
torch_geometric/datasets/modelnet.py,sha256=-qmLjlQiKVWmtHefAIIE97dQxEcaBfetMJnvgYZuwkg,5347
|
|
122
|
-
torch_geometric/datasets/molecule_gpt_dataset.py,sha256=
|
|
122
|
+
torch_geometric/datasets/molecule_gpt_dataset.py,sha256=TFBduE3_3xxTFSHL3tirV-OAlBjSi6iHPOHJGQ_-tug,18785
|
|
123
123
|
torch_geometric/datasets/molecule_net.py,sha256=pMzaJzd-LbBncZ0VoC87HfA8d1F4NwCWTb5YKvLM890,7404
|
|
124
124
|
torch_geometric/datasets/movie_lens.py,sha256=M4Bu0Xus8IkW8GYzjxPxSdPXNbcCCx9cu6cncxBvLx8,4033
|
|
125
125
|
torch_geometric/datasets/movie_lens_100k.py,sha256=eTpBAteM3jqTEtiwLxmhVj4r8JvftvPx8Hvs-3ZIHlU,6057
|
|
@@ -139,8 +139,9 @@ torch_geometric/datasets/pcqm4m.py,sha256=7ID_xXXIAyuNzYLI2lBWygZl9wGos-dbaz1b6E
|
|
|
139
139
|
torch_geometric/datasets/planetoid.py,sha256=RksfwR_PI7qGVphs-T-4jXDepYwQCweMXElLm096hgg,7201
|
|
140
140
|
torch_geometric/datasets/polblogs.py,sha256=IYzsvd4R0OojmOOZUoOdCwQYfwcTfth1PNtcBK1yOGc,3045
|
|
141
141
|
torch_geometric/datasets/ppi.py,sha256=zPtg-omC7WYvr9Tzwkb7zNjpXLODsvxKxKdGEUswp2E,5030
|
|
142
|
+
torch_geometric/datasets/protein_mpnn_dataset.py,sha256=TTeTVJMo0Rlt2_h9bbZMKJe3rTJcjCgY5cXGyWteBfA,17756
|
|
142
143
|
torch_geometric/datasets/qm7.py,sha256=bYyK8xlh9kTr5vqueNbLu9EAjIXkQH1KX1VWnjKfOJc,3323
|
|
143
|
-
torch_geometric/datasets/qm9.py,sha256=
|
|
144
|
+
torch_geometric/datasets/qm9.py,sha256=Ub1t8KNeWFZvw50_Qk-80yNFeYFDwdAeyQtp3JHZs7o,17197
|
|
144
145
|
torch_geometric/datasets/rcdd.py,sha256=gvOoM1tw_X5QMyBB4FkMUwNErMXAvImyjz5twktBAh8,5317
|
|
145
146
|
torch_geometric/datasets/reddit.py,sha256=QUgiKTaj6YTOYbgWgqV8mPYsctOui2ujaM8f8qy81v0,3131
|
|
146
147
|
torch_geometric/datasets/reddit2.py,sha256=WSdrhbDPcUEG37XWNUd0uKnqgI911MOcfjXmgjbTPoQ,4291
|
|
@@ -153,6 +154,7 @@ torch_geometric/datasets/snap_dataset.py,sha256=deJvB6cpIQ3bu_pcWoqgEo1-Kl_NcFi7
|
|
|
153
154
|
torch_geometric/datasets/suite_sparse.py,sha256=eqjH4vAUq872qdk3YdLkZSwlu6r7HHpTgK0vEVGmY1s,3278
|
|
154
155
|
torch_geometric/datasets/tag_dataset.py,sha256=qTnwr2N1tbWYeLGbItfv70UxQ3n1rKesjeVU3kcOCP8,14757
|
|
155
156
|
torch_geometric/datasets/taobao.py,sha256=CUcZpbWsNTasevflO8zqP0YvENy89P7wpKS4MHaDJ6Q,4170
|
|
157
|
+
torch_geometric/datasets/teeth3ds.py,sha256=hZvhcq9lsQENNFr5hk50w2T3CgxE_tlnQfrCgN6uIDQ,9919
|
|
156
158
|
torch_geometric/datasets/tosca.py,sha256=nUSF8NQT1GlkwWQLshjWmr8xORsvRHzzIqhUyDCvABc,4632
|
|
157
159
|
torch_geometric/datasets/tu_dataset.py,sha256=14OSaXBgVwT1dX2h1wZ3xVIwoo0GQBEfR3yWh6Q0VF0,7847
|
|
158
160
|
torch_geometric/datasets/twitch.py,sha256=qfEerf-Uaojx2ZvegENowdG4E7RoUT_HUO9xtULadvo,3658
|
|
@@ -333,8 +335,9 @@ torch_geometric/nn/aggr/set_transformer.py,sha256=FG7_JizpFX14M6VSCwLSjYXYdJ1ZiQ
|
|
|
333
335
|
torch_geometric/nn/aggr/sort.py,sha256=bvOOWnFkNOBOZih4rqVZQsjfeDX3vmXo1bpPSFD846w,2507
|
|
334
336
|
torch_geometric/nn/aggr/utils.py,sha256=SQvdc0g6p_E2j0prA14MW2ekjEDvV-g545N0Q85uc-o,8625
|
|
335
337
|
torch_geometric/nn/aggr/variance_preserving.py,sha256=fu-U_aGYpVLpgSFvVg0ONMe6nqoyv8tZ6Y35qMYTf9w,1126
|
|
336
|
-
torch_geometric/nn/attention/__init__.py,sha256=
|
|
338
|
+
torch_geometric/nn/attention/__init__.py,sha256=w-jDQFpVqARJKjttTgKkD9kkAqRJl4MpASCfiNYIfr0,263
|
|
337
339
|
torch_geometric/nn/attention/performer.py,sha256=2PCDn4_-oNTao2-DkXIaoi18anP01OxRELF2pvp-jk8,7357
|
|
340
|
+
torch_geometric/nn/attention/polynormer.py,sha256=uBxGs0nldp6oGlByqbxgEk23VeXLEd6B3myS5BOKDRs,3998
|
|
338
341
|
torch_geometric/nn/attention/qformer.py,sha256=7J-pWm_vpumK38IC-iCBz4oqL-BEIofEIxJ0wfjWq9A,2338
|
|
339
342
|
torch_geometric/nn/attention/sgformer.py,sha256=OBC5HQxbY289bPDtwN8UbPH46To2GRTeVN-najogD-o,3747
|
|
340
343
|
torch_geometric/nn/conv/__init__.py,sha256=8CK-DFG2PEo2ZaFyg-IUlQH8ecQoDDi556uv3ugeQyc,3572
|
|
@@ -374,7 +377,7 @@ torch_geometric/nn/conv/hgt_conv.py,sha256=lUhTWUMovMtn9yR_b2-kLNLqHChGOUl2OtXBY
|
|
|
374
377
|
torch_geometric/nn/conv/hypergraph_conv.py,sha256=4BosbbqJyprlI6QjPqIfMxCqnARU_0mUn1zcAQhbw90,8691
|
|
375
378
|
torch_geometric/nn/conv/le_conv.py,sha256=DonmmYZOKk5wIlTZzzIfNKqBY6MO0MRxYhyr0YtNz-Q,3494
|
|
376
379
|
torch_geometric/nn/conv/lg_conv.py,sha256=8jMa79iPsOUbXEfBIc3wmbvAD8T3d1j37LeIFTX3Yag,2369
|
|
377
|
-
torch_geometric/nn/conv/meshcnn_conv.py,sha256=
|
|
380
|
+
torch_geometric/nn/conv/meshcnn_conv.py,sha256=92zUcgfS0Fwv-MpddF4Ia1a65y7ddPAkazYf7D6kvwg,21951
|
|
378
381
|
torch_geometric/nn/conv/message_passing.py,sha256=ZuTvSvodGy1GyAW4mHtuoMUuxclam-7opidYNY5IHm8,44377
|
|
379
382
|
torch_geometric/nn/conv/mf_conv.py,sha256=SkOGMN1tFT9dcqy8xYowsB2ozw6QfkoArgR1BksZZaU,4340
|
|
380
383
|
torch_geometric/nn/conv/mixhop_conv.py,sha256=qVDPWeWcnO7_eHM0ZnpKtr8SISjb4jp0xjgpoDrwjlk,4555
|
|
@@ -429,7 +432,7 @@ torch_geometric/nn/kge/distmult.py,sha256=dGQ0bVzjreZgFN1lXE23_IIidsiOq7ehPrMb-N
|
|
|
429
432
|
torch_geometric/nn/kge/loader.py,sha256=5Uc1j3OUMQnBYSHDqL7pLCty1siFLzoPkztigYO2zP8,771
|
|
430
433
|
torch_geometric/nn/kge/rotate.py,sha256=XLuO1AbyTt5cJxr97ZzoyAyIEsHKesgW5TvDmnGJAao,3208
|
|
431
434
|
torch_geometric/nn/kge/transe.py,sha256=jlejq5BLMm-sb1wWcLDp7pZqCdelWBgjDIC8ctbjSdU,3088
|
|
432
|
-
torch_geometric/nn/models/__init__.py,sha256=
|
|
435
|
+
torch_geometric/nn/models/__init__.py,sha256=71Hqc-ZMfCKn9lelFYDjpHXapbEa0wqVAd2OXCb1y5o,2448
|
|
433
436
|
torch_geometric/nn/models/attentive_fp.py,sha256=1z3iTV2O5W9tqHFAdno8FeBFeXmuG-TDZk4lwwVh3Ac,6634
|
|
434
437
|
torch_geometric/nn/models/attract_repel.py,sha256=h9OyogT0NY0xiT0DkpJHMxH6ZUmo8R-CmwZdKEwq8Ek,5277
|
|
435
438
|
torch_geometric/nn/models/autoencoder.py,sha256=nGje-zty78Y3hxOJ9o0_6QziJjOvBlknk6z0_fDQwQU,10770
|
|
@@ -442,7 +445,7 @@ torch_geometric/nn/models/dimenet.py,sha256=O2rqEx5HWs_lMwRD8eq6WMkbqJaCLL5zgWUJ
|
|
|
442
445
|
torch_geometric/nn/models/dimenet_utils.py,sha256=Eyn_EiJqwKvuYj6BtRpSxrzMG3v4Gk98X9MxZ7uvwm4,5069
|
|
443
446
|
torch_geometric/nn/models/g_retriever.py,sha256=tVibbqM_r-1LnA3R3oVyzp0bpuN3qPoYqcU6LZ8dYEk,8260
|
|
444
447
|
torch_geometric/nn/models/git_mol.py,sha256=Wc6Hx6RDDR7sDWRWHfA5eK9e9gFsrTZ9OLmpMfoj3pE,12676
|
|
445
|
-
torch_geometric/nn/models/glem.py,sha256=
|
|
448
|
+
torch_geometric/nn/models/glem.py,sha256=GlL_I63g-_5eTycSGRj720YntldQ-CQ351RaDPc6XAU,16674
|
|
446
449
|
torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
|
|
447
450
|
torch_geometric/nn/models/gpse.py,sha256=acEAeeicLgzKRL54WhvIFxjA5XViHgXgMEH-NgbMdqI,41971
|
|
448
451
|
torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
|
|
@@ -459,6 +462,8 @@ torch_geometric/nn/models/molecule_gpt.py,sha256=k-XULH6jaurj-R2EE4sIWTkqlNqa3Cz
|
|
|
459
462
|
torch_geometric/nn/models/neural_fingerprint.py,sha256=pTLJgU9Uh2Lnf9bggLj4cKI8YdEFcMF-9MALuubqbuQ,2378
|
|
460
463
|
torch_geometric/nn/models/node2vec.py,sha256=81Ku4Rp4IwLEAy06KEgJ2fYtXXVL_uv_Hb8lBr6YXrE,7664
|
|
461
464
|
torch_geometric/nn/models/pmlp.py,sha256=dcAASVSyQMMhItSfEJWPeAFh0R3tNCwAHwdrShwQ8o4,3538
|
|
465
|
+
torch_geometric/nn/models/polynormer.py,sha256=mayWdzdolT5PCt_Oo7UGG-JUripMHWB2lUWF1bh6goU,7640
|
|
466
|
+
torch_geometric/nn/models/protein_mpnn.py,sha256=QXHfltiJPmakpzgJKw_1vwCGBlszv9nfY4r4F38Sg9k,11031
|
|
462
467
|
torch_geometric/nn/models/re_net.py,sha256=pz66q5b5BoGDNVQvpEGS2RGoeKvpjkYAv9r3WAuvITk,8986
|
|
463
468
|
torch_geometric/nn/models/rect.py,sha256=2F3XyyvHTAEuqfJpiNB5M8pSGy738LhPiom5I-SDWqM,2808
|
|
464
469
|
torch_geometric/nn/models/rev_gnn.py,sha256=Bpme087Zs227lcB0ODOKWsxaly67q96wseaRt6bacjs,11796
|
|
@@ -613,7 +618,7 @@ torch_geometric/utils/_tree_decomposition.py,sha256=ZtpjPQJgXbQWtSWjo-Fmhrov0DGO
|
|
|
613
618
|
torch_geometric/utils/_trim_to_layer.py,sha256=cauOEzMJJK4w9BC-Pg1bHVncBYqG9XxQex3rn10BFjc,8339
|
|
614
619
|
torch_geometric/utils/_unbatch.py,sha256=B0vjKI96PtHvSBG8F_lqvsiJE134aVjUurPZsG6UZRI,2378
|
|
615
620
|
torch_geometric/utils/augmentation.py,sha256=1F0YCuaklZ9ZbXxdFV0oOoemWvLd8p60WvFo2chzl7E,8600
|
|
616
|
-
torch_geometric/utils/convert.py,sha256=
|
|
621
|
+
torch_geometric/utils/convert.py,sha256=RE5n5no74Xu39-QMWFE0-1RvTgykdK33ymyjF9WcuSs,21938
|
|
617
622
|
torch_geometric/utils/cross_entropy.py,sha256=ZFS5bivtzv3EV9zqgKsekmuQyoZZggPSclhl_tRNHxo,3047
|
|
618
623
|
torch_geometric/utils/dropout.py,sha256=gg0rDnD4FLvBaKSoLAkZwViAQflhLefJm6_Mju5dmQs,11416
|
|
619
624
|
torch_geometric/utils/embedding.py,sha256=Ac_MPSrZGpw-e-gU6Yz-seUioC2WZxBSSzXFeclGwMk,5232
|
|
@@ -634,13 +639,13 @@ torch_geometric/utils/num_nodes.py,sha256=F15ciTFOe8AxjkUh1wKH7RLmJvQYYpz-l3pPPv
|
|
|
634
639
|
torch_geometric/utils/ppr.py,sha256=ebiHbQqRJsQbGUI5xu-IkzQSQsgIaC71vgO0KcXIKAk,4055
|
|
635
640
|
torch_geometric/utils/random.py,sha256=Rv5HlhG5310rytbT9EZ7xWLGKQfozfz1azvYi5nx2-U,5148
|
|
636
641
|
torch_geometric/utils/repeat.py,sha256=RxCoRoEisaP6NouXPPW5tY1Rn-tIfrmpJPm0qGP6W8M,815
|
|
637
|
-
torch_geometric/utils/smiles.py,sha256=
|
|
642
|
+
torch_geometric/utils/smiles.py,sha256=CFqeNtSBXQtY9Ex2gQzI0La490IpVVrm01QdRYEpV7w,7114
|
|
638
643
|
torch_geometric/utils/sparse.py,sha256=1DbaEwdyvnzvg5qVjPlnWcEVDMkxrQLX1jJ0dr6P4js,25135
|
|
639
644
|
torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5nUAUjw,6222
|
|
640
645
|
torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
|
|
641
646
|
torch_geometric/visualization/graph.py,sha256=mfZHXYfiU-CWMtfawYc80IxVwVmtK9hbIkSKhM_j7oI,14311
|
|
642
647
|
torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
|
|
643
|
-
pyg_nightly-2.7.0.
|
|
644
|
-
pyg_nightly-2.7.0.
|
|
645
|
-
pyg_nightly-2.7.0.
|
|
646
|
-
pyg_nightly-2.7.0.
|
|
648
|
+
pyg_nightly-2.7.0.dev20250704.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
|
|
649
|
+
pyg_nightly-2.7.0.dev20250704.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
|
|
650
|
+
pyg_nightly-2.7.0.dev20250704.dist-info/METADATA,sha256=Nau44bIMI13OXEqYNOlI1hYfDz8FXpUoPydv-JxRW2Q,63005
|
|
651
|
+
pyg_nightly-2.7.0.dev20250704.dist-info/RECORD,,
|
torch_geometric/__init__.py
CHANGED
|
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
|
|
|
31
31
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
|
32
32
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
|
33
33
|
|
|
34
|
-
__version__ = '2.7.0.
|
|
34
|
+
__version__ = '2.7.0.dev20250704'
|
|
35
35
|
|
|
36
36
|
__all__ = [
|
|
37
37
|
'Index',
|
|
@@ -81,8 +81,10 @@ from .web_qsp_dataset import WebQSPDataset, CWQDataset
|
|
|
81
81
|
from .git_mol_dataset import GitMolDataset
|
|
82
82
|
from .molecule_gpt_dataset import MoleculeGPTDataset
|
|
83
83
|
from .instruct_mol_dataset import InstructMolDataset
|
|
84
|
+
from .protein_mpnn_dataset import ProteinMPNNDataset
|
|
84
85
|
from .tag_dataset import TAGDataset
|
|
85
86
|
from .city import CityNetwork
|
|
87
|
+
from .teeth3ds import Teeth3DS
|
|
86
88
|
|
|
87
89
|
from .dbp15k import DBP15K
|
|
88
90
|
from .aminer import AMiner
|
|
@@ -201,8 +203,10 @@ homo_datasets = [
|
|
|
201
203
|
'GitMolDataset',
|
|
202
204
|
'MoleculeGPTDataset',
|
|
203
205
|
'InstructMolDataset',
|
|
206
|
+
'ProteinMPNNDataset',
|
|
204
207
|
'TAGDataset',
|
|
205
208
|
'CityNetwork',
|
|
209
|
+
'Teeth3DS',
|
|
206
210
|
]
|
|
207
211
|
|
|
208
212
|
hetero_datasets = [
|
|
@@ -438,7 +438,7 @@ class MoleculeGPTDataset(InMemoryDataset):
|
|
|
438
438
|
for mol in tqdm(suppl):
|
|
439
439
|
if mol.HasProp('PUBCHEM_COMPOUND_CID'):
|
|
440
440
|
CID = mol.GetProp("PUBCHEM_COMPOUND_CID")
|
|
441
|
-
CAN_SMILES = mol.GetProp("
|
|
441
|
+
CAN_SMILES = mol.GetProp("PUBCHEM_SMILES")
|
|
442
442
|
|
|
443
443
|
m: Chem.Mol = Chem.MolFromSmiles(CAN_SMILES)
|
|
444
444
|
if m is None:
|
|
@@ -0,0 +1,451 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pickle
|
|
3
|
+
import random
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
|
|
11
|
+
from torch_geometric.data import (
|
|
12
|
+
Data,
|
|
13
|
+
InMemoryDataset,
|
|
14
|
+
download_url,
|
|
15
|
+
extract_tar,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ProteinMPNNDataset(InMemoryDataset):
|
|
20
|
+
r"""The ProteinMPNN dataset from the `"Robust deep learning based protein
|
|
21
|
+
sequence design using ProteinMPNN"
|
|
22
|
+
<https://www.biorxiv.org/content/10.1101/2022.06.03.494563v1>`_ paper.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
root (str): Root directory where the dataset should be saved.
|
|
26
|
+
size (str): Size of the PDB information to train the model.
|
|
27
|
+
If :obj:`"small"`, loads the small dataset (229.4 MB).
|
|
28
|
+
If :obj:`"large"`, loads the large dataset (64.1 GB).
|
|
29
|
+
(default: :obj:`"small"`)
|
|
30
|
+
split (str, optional): If :obj:`"train"`, loads the training dataset.
|
|
31
|
+
If :obj:`"valid"`, loads the validation dataset.
|
|
32
|
+
If :obj:`"test"`, loads the test dataset.
|
|
33
|
+
(default: :obj:`"train"`)
|
|
34
|
+
datacut (str, optional): Date cutoff to filter the dataset.
|
|
35
|
+
(default: :obj:`"2030-01-01"`)
|
|
36
|
+
rescut (float, optional): PDB resolution cutoff.
|
|
37
|
+
(default: :obj:`3.5`)
|
|
38
|
+
homo (float, optional): Homology cutoff.
|
|
39
|
+
(default: :obj:`0.70`)
|
|
40
|
+
max_length (int, optional): Maximum length of the protein complex.
|
|
41
|
+
(default: :obj:`10000`)
|
|
42
|
+
num_units (int, optional): Number of units of the protein complex.
|
|
43
|
+
(default: :obj:`150`)
|
|
44
|
+
transform (callable, optional): A function/transform that takes in an
|
|
45
|
+
:obj:`torch_geometric.data.Data` object and returns a transformed
|
|
46
|
+
version. The data object will be transformed before every access.
|
|
47
|
+
(default: :obj:`None`)
|
|
48
|
+
pre_transform (callable, optional): A function/transform that takes in
|
|
49
|
+
an :obj:`torch_geometric.data.Data` object and returns a
|
|
50
|
+
transformed version. The data object will be transformed before
|
|
51
|
+
being saved to disk. (default: :obj:`None`)
|
|
52
|
+
pre_filter (callable, optional): A function that takes in an
|
|
53
|
+
:obj:`torch_geometric.data.Data` object and returns a boolean
|
|
54
|
+
value, indicating whether the data object should be included in the
|
|
55
|
+
final dataset. (default: :obj:`None`)
|
|
56
|
+
force_reload (bool, optional): Whether to re-process the dataset.
|
|
57
|
+
(default: :obj:`False`)
|
|
58
|
+
"""
|
|
59
|
+
|
|
60
|
+
raw_url = {
|
|
61
|
+
'small':
|
|
62
|
+
'https://files.ipd.uw.edu/pub/training_sets/'
|
|
63
|
+
'pdb_2021aug02_sample.tar.gz',
|
|
64
|
+
'large':
|
|
65
|
+
'https://files.ipd.uw.edu/pub/training_sets/'
|
|
66
|
+
'pdb_2021aug02.tar.gz',
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
splits = {
|
|
70
|
+
'train': 1,
|
|
71
|
+
'valid': 2,
|
|
72
|
+
'test': 3,
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
def __init__(
|
|
76
|
+
self,
|
|
77
|
+
root: str,
|
|
78
|
+
size: str = 'small',
|
|
79
|
+
split: str = 'train',
|
|
80
|
+
datacut: str = '2030-01-01',
|
|
81
|
+
rescut: float = 3.5,
|
|
82
|
+
homo: float = 0.70,
|
|
83
|
+
max_length: int = 10000,
|
|
84
|
+
num_units: int = 150,
|
|
85
|
+
transform: Optional[Callable] = None,
|
|
86
|
+
pre_transform: Optional[Callable] = None,
|
|
87
|
+
pre_filter: Optional[Callable] = None,
|
|
88
|
+
force_reload: bool = False,
|
|
89
|
+
) -> None:
|
|
90
|
+
self.size = size
|
|
91
|
+
self.split = split
|
|
92
|
+
self.datacut = datacut
|
|
93
|
+
self.rescut = rescut
|
|
94
|
+
self.homo = homo
|
|
95
|
+
self.max_length = max_length
|
|
96
|
+
self.num_units = num_units
|
|
97
|
+
|
|
98
|
+
self.sub_folder = self.raw_url[self.size].split('/')[-1].split('.')[0]
|
|
99
|
+
|
|
100
|
+
super().__init__(root, transform, pre_transform, pre_filter,
|
|
101
|
+
force_reload=force_reload)
|
|
102
|
+
self.load(self.processed_paths[self.splits[self.split]])
|
|
103
|
+
|
|
104
|
+
@property
|
|
105
|
+
def raw_file_names(self) -> List[str]:
|
|
106
|
+
return [
|
|
107
|
+
f'{self.sub_folder}/{f}'
|
|
108
|
+
for f in ['list.csv', 'valid_clusters.txt', 'test_clusters.txt']
|
|
109
|
+
]
|
|
110
|
+
|
|
111
|
+
@property
|
|
112
|
+
def processed_file_names(self) -> List[str]:
|
|
113
|
+
return ['splits.pkl', 'train.pt', 'valid.pt', 'test.pt']
|
|
114
|
+
|
|
115
|
+
def download(self) -> None:
|
|
116
|
+
file_path = download_url(self.raw_url[self.size], self.raw_dir)
|
|
117
|
+
extract_tar(file_path, self.raw_dir)
|
|
118
|
+
os.unlink(file_path)
|
|
119
|
+
|
|
120
|
+
def process(self) -> None:
|
|
121
|
+
alphabet_set = set(list('ACDEFGHIKLMNPQRSTVWYX'))
|
|
122
|
+
cluster_ids = self._process_split()
|
|
123
|
+
total_items = sum(len(items) for items in cluster_ids.values())
|
|
124
|
+
data_list = []
|
|
125
|
+
|
|
126
|
+
with tqdm(total=total_items, desc="Processing") as pbar:
|
|
127
|
+
for _, items in cluster_ids.items():
|
|
128
|
+
for chain_id, _ in items:
|
|
129
|
+
item = self._process_pdb1(chain_id)
|
|
130
|
+
|
|
131
|
+
if 'label' not in item:
|
|
132
|
+
pbar.update(1)
|
|
133
|
+
continue
|
|
134
|
+
if len(list(np.unique(item['idx']))) >= 352:
|
|
135
|
+
pbar.update(1)
|
|
136
|
+
continue
|
|
137
|
+
|
|
138
|
+
my_dict = self._process_pdb2(item)
|
|
139
|
+
|
|
140
|
+
if len(my_dict['seq']) > self.max_length:
|
|
141
|
+
pbar.update(1)
|
|
142
|
+
continue
|
|
143
|
+
bad_chars = set(list(
|
|
144
|
+
my_dict['seq'])).difference(alphabet_set)
|
|
145
|
+
if len(bad_chars) > 0:
|
|
146
|
+
pbar.update(1)
|
|
147
|
+
continue
|
|
148
|
+
|
|
149
|
+
x_chain_all, chain_seq_label_all, mask, chain_mask_all, residue_idx, chain_encoding_all = self._process_pdb3( # noqa: E501
|
|
150
|
+
my_dict)
|
|
151
|
+
|
|
152
|
+
data = Data(
|
|
153
|
+
x=x_chain_all, # [seq_len, 4, 3]
|
|
154
|
+
chain_seq_label=chain_seq_label_all, # [seq_len]
|
|
155
|
+
mask=mask, # [seq_len]
|
|
156
|
+
chain_mask_all=chain_mask_all, # [seq_len]
|
|
157
|
+
residue_idx=residue_idx, # [seq_len]
|
|
158
|
+
chain_encoding_all=chain_encoding_all, # [seq_len]
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
if self.pre_filter is not None and not self.pre_filter(
|
|
162
|
+
data):
|
|
163
|
+
continue
|
|
164
|
+
if self.pre_transform is not None:
|
|
165
|
+
data = self.pre_transform(data)
|
|
166
|
+
|
|
167
|
+
data_list.append(data)
|
|
168
|
+
|
|
169
|
+
if len(data_list) >= self.num_units:
|
|
170
|
+
pbar.update(total_items - pbar.n)
|
|
171
|
+
break
|
|
172
|
+
pbar.update(1)
|
|
173
|
+
else:
|
|
174
|
+
continue
|
|
175
|
+
break
|
|
176
|
+
self.save(data_list, self.processed_paths[self.splits[self.split]])
|
|
177
|
+
|
|
178
|
+
def _process_split(self) -> Dict[int, List[Tuple[str, int]]]:
|
|
179
|
+
import pandas as pd
|
|
180
|
+
save_path = self.processed_paths[0]
|
|
181
|
+
|
|
182
|
+
if os.path.exists(save_path):
|
|
183
|
+
print('Load split')
|
|
184
|
+
with open(save_path, 'rb') as f:
|
|
185
|
+
data = pickle.load(f)
|
|
186
|
+
else:
|
|
187
|
+
# CHAINID, DEPOSITION, RESOLUTION, HASH, CLUSTER, SEQUENCE
|
|
188
|
+
df = pd.read_csv(self.raw_paths[0])
|
|
189
|
+
df = df[(df['RESOLUTION'] <= self.rescut)
|
|
190
|
+
& (df['DEPOSITION'] <= self.datacut)]
|
|
191
|
+
|
|
192
|
+
val_ids = pd.read_csv(self.raw_paths[1], header=None)[0].tolist()
|
|
193
|
+
test_ids = pd.read_csv(self.raw_paths[2], header=None)[0].tolist()
|
|
194
|
+
|
|
195
|
+
# compile training and validation sets
|
|
196
|
+
data = {
|
|
197
|
+
'train': defaultdict(list),
|
|
198
|
+
'valid': defaultdict(list),
|
|
199
|
+
'test': defaultdict(list),
|
|
200
|
+
}
|
|
201
|
+
|
|
202
|
+
for _, r in tqdm(df.iterrows(), desc='Processing split',
|
|
203
|
+
total=len(df)):
|
|
204
|
+
cluster_id = r['CLUSTER']
|
|
205
|
+
hash_id = r['HASH']
|
|
206
|
+
chain_id = r['CHAINID']
|
|
207
|
+
if cluster_id in val_ids:
|
|
208
|
+
data['valid'][cluster_id].append((chain_id, hash_id))
|
|
209
|
+
elif cluster_id in test_ids:
|
|
210
|
+
data['test'][cluster_id].append((chain_id, hash_id))
|
|
211
|
+
else:
|
|
212
|
+
data['train'][cluster_id].append((chain_id, hash_id))
|
|
213
|
+
|
|
214
|
+
with open(save_path, 'wb') as f:
|
|
215
|
+
pickle.dump(data, f)
|
|
216
|
+
|
|
217
|
+
return data[self.split]
|
|
218
|
+
|
|
219
|
+
def _process_pdb1(self, chain_id: str) -> Dict[str, Any]:
|
|
220
|
+
pdbid, chid = chain_id.split('_')
|
|
221
|
+
prefix = f'{self.raw_dir}/{self.sub_folder}/pdb/{pdbid[1:3]}/{pdbid}'
|
|
222
|
+
# load metadata
|
|
223
|
+
if not os.path.isfile(f'{prefix}.pt'):
|
|
224
|
+
return {'seq': np.zeros(5)}
|
|
225
|
+
meta = torch.load(f'{prefix}.pt')
|
|
226
|
+
asmb_ids = meta['asmb_ids']
|
|
227
|
+
asmb_chains = meta['asmb_chains']
|
|
228
|
+
chids = np.array(meta['chains'])
|
|
229
|
+
|
|
230
|
+
# find candidate assemblies which contain chid chain
|
|
231
|
+
asmb_candidates = {
|
|
232
|
+
a
|
|
233
|
+
for a, b in zip(asmb_ids, asmb_chains) if chid in b.split(',')
|
|
234
|
+
}
|
|
235
|
+
|
|
236
|
+
# if the chains is missing is missing from all the assemblies
|
|
237
|
+
# then return this chain alone
|
|
238
|
+
if len(asmb_candidates) < 1:
|
|
239
|
+
chain = torch.load(f'{prefix}_{chid}.pt')
|
|
240
|
+
L = len(chain['seq'])
|
|
241
|
+
return {
|
|
242
|
+
'seq': chain['seq'],
|
|
243
|
+
'xyz': chain['xyz'],
|
|
244
|
+
'idx': torch.zeros(L).int(),
|
|
245
|
+
'masked': torch.Tensor([0]).int(),
|
|
246
|
+
'label': chain_id,
|
|
247
|
+
}
|
|
248
|
+
|
|
249
|
+
# randomly pick one assembly from candidates
|
|
250
|
+
asmb_i = random.sample(list(asmb_candidates), 1)
|
|
251
|
+
|
|
252
|
+
# indices of selected transforms
|
|
253
|
+
idx = np.where(np.array(asmb_ids) == asmb_i)[0]
|
|
254
|
+
|
|
255
|
+
# load relevant chains
|
|
256
|
+
chains = {
|
|
257
|
+
c: torch.load(f'{prefix}_{c}.pt')
|
|
258
|
+
for i in idx
|
|
259
|
+
for c in asmb_chains[i] if c in meta['chains']
|
|
260
|
+
}
|
|
261
|
+
|
|
262
|
+
# generate assembly
|
|
263
|
+
asmb = {}
|
|
264
|
+
for k in idx:
|
|
265
|
+
|
|
266
|
+
# pick k-th xform
|
|
267
|
+
xform = meta[f'asmb_xform{k}']
|
|
268
|
+
u = xform[:, :3, :3]
|
|
269
|
+
r = xform[:, :3, 3]
|
|
270
|
+
|
|
271
|
+
# select chains which k-th xform should be applied to
|
|
272
|
+
s1 = set(meta['chains'])
|
|
273
|
+
s2 = set(asmb_chains[k].split(','))
|
|
274
|
+
chains_k = s1 & s2
|
|
275
|
+
|
|
276
|
+
# transform selected chains
|
|
277
|
+
for c in chains_k:
|
|
278
|
+
try:
|
|
279
|
+
xyz = chains[c]['xyz']
|
|
280
|
+
xyz_ru = torch.einsum('bij,raj->brai', u, xyz) + r[:, None,
|
|
281
|
+
None, :]
|
|
282
|
+
asmb.update({
|
|
283
|
+
(c, k, i): xyz_i
|
|
284
|
+
for i, xyz_i in enumerate(xyz_ru)
|
|
285
|
+
})
|
|
286
|
+
except KeyError:
|
|
287
|
+
return {'seq': np.zeros(5)}
|
|
288
|
+
|
|
289
|
+
# select chains which share considerable similarity to chid
|
|
290
|
+
seqid = meta['tm'][chids == chid][0, :, 1]
|
|
291
|
+
homo = {
|
|
292
|
+
ch_j
|
|
293
|
+
for seqid_j, ch_j in zip(seqid, chids) if seqid_j > self.homo
|
|
294
|
+
}
|
|
295
|
+
# stack all chains in the assembly together
|
|
296
|
+
seq: str = ''
|
|
297
|
+
xyz_all: List[torch.Tensor] = []
|
|
298
|
+
idx_all: List[torch.Tensor] = []
|
|
299
|
+
masked: List[int] = []
|
|
300
|
+
seq_list = []
|
|
301
|
+
for counter, (k, v) in enumerate(asmb.items()):
|
|
302
|
+
seq += chains[k[0]]['seq']
|
|
303
|
+
seq_list.append(chains[k[0]]['seq'])
|
|
304
|
+
xyz_all.append(v)
|
|
305
|
+
idx_all.append(torch.full((v.shape[0], ), counter))
|
|
306
|
+
if k[0] in homo:
|
|
307
|
+
masked.append(counter)
|
|
308
|
+
|
|
309
|
+
return {
|
|
310
|
+
'seq': seq,
|
|
311
|
+
'xyz': torch.cat(xyz_all, dim=0),
|
|
312
|
+
'idx': torch.cat(idx_all, dim=0),
|
|
313
|
+
'masked': torch.Tensor(masked).int(),
|
|
314
|
+
'label': chain_id,
|
|
315
|
+
}
|
|
316
|
+
|
|
317
|
+
def _process_pdb2(self, t: Dict[str, Any]) -> Dict[str, Any]:
|
|
318
|
+
init_alphabet = list(
|
|
319
|
+
'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz')
|
|
320
|
+
extra_alphabet = [str(item) for item in list(np.arange(300))]
|
|
321
|
+
chain_alphabet = init_alphabet + extra_alphabet
|
|
322
|
+
my_dict: Dict[str, Union[str, int, Dict[str, Any], List[Any]]] = {}
|
|
323
|
+
concat_seq = ''
|
|
324
|
+
mask_list = []
|
|
325
|
+
visible_list = []
|
|
326
|
+
for idx in list(np.unique(t['idx'])):
|
|
327
|
+
letter = chain_alphabet[idx]
|
|
328
|
+
res = np.argwhere(t['idx'] == idx)
|
|
329
|
+
initial_sequence = "".join(list(
|
|
330
|
+
np.array(list(t['seq']))[res][
|
|
331
|
+
0,
|
|
332
|
+
]))
|
|
333
|
+
if initial_sequence[-6:] == "HHHHHH":
|
|
334
|
+
res = res[:, :-6]
|
|
335
|
+
if initial_sequence[0:6] == "HHHHHH":
|
|
336
|
+
res = res[:, 6:]
|
|
337
|
+
if initial_sequence[-7:-1] == "HHHHHH":
|
|
338
|
+
res = res[:, :-7]
|
|
339
|
+
if initial_sequence[-8:-2] == "HHHHHH":
|
|
340
|
+
res = res[:, :-8]
|
|
341
|
+
if initial_sequence[-9:-3] == "HHHHHH":
|
|
342
|
+
res = res[:, :-9]
|
|
343
|
+
if initial_sequence[-10:-4] == "HHHHHH":
|
|
344
|
+
res = res[:, :-10]
|
|
345
|
+
if initial_sequence[1:7] == "HHHHHH":
|
|
346
|
+
res = res[:, 7:]
|
|
347
|
+
if initial_sequence[2:8] == "HHHHHH":
|
|
348
|
+
res = res[:, 8:]
|
|
349
|
+
if initial_sequence[3:9] == "HHHHHH":
|
|
350
|
+
res = res[:, 9:]
|
|
351
|
+
if initial_sequence[4:10] == "HHHHHH":
|
|
352
|
+
res = res[:, 10:]
|
|
353
|
+
if res.shape[1] >= 4:
|
|
354
|
+
chain_seq = "".join(list(np.array(list(t['seq']))[res][0]))
|
|
355
|
+
my_dict[f'seq_chain_{letter}'] = chain_seq
|
|
356
|
+
concat_seq += chain_seq
|
|
357
|
+
if idx in t['masked']:
|
|
358
|
+
mask_list.append(letter)
|
|
359
|
+
else:
|
|
360
|
+
visible_list.append(letter)
|
|
361
|
+
coords_dict_chain = {}
|
|
362
|
+
all_atoms = np.array(t['xyz'][res])[0] # [L, 14, 3]
|
|
363
|
+
for i, c in enumerate(['N', 'CA', 'C', 'O']):
|
|
364
|
+
coords_dict_chain[
|
|
365
|
+
f'{c}_chain_{letter}'] = all_atoms[:, i, :].tolist()
|
|
366
|
+
my_dict[f'coords_chain_{letter}'] = coords_dict_chain
|
|
367
|
+
my_dict['name'] = t['label']
|
|
368
|
+
my_dict['masked_list'] = mask_list
|
|
369
|
+
my_dict['visible_list'] = visible_list
|
|
370
|
+
my_dict['num_of_chains'] = len(mask_list) + len(visible_list)
|
|
371
|
+
my_dict['seq'] = concat_seq
|
|
372
|
+
return my_dict
|
|
373
|
+
|
|
374
|
+
def _process_pdb3(
|
|
375
|
+
self, b: Dict[str, Any]
|
|
376
|
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
|
|
377
|
+
torch.Tensor, torch.Tensor]:
|
|
378
|
+
L = len(b['seq'])
|
|
379
|
+
# residue idx with jumps across chains
|
|
380
|
+
residue_idx = -100 * np.ones([L], dtype=np.int32)
|
|
381
|
+
# get the list of masked / visible chains
|
|
382
|
+
masked_chains, visible_chains = b['masked_list'], b['visible_list']
|
|
383
|
+
visible_temp_dict, masked_temp_dict = {}, {}
|
|
384
|
+
for letter in masked_chains + visible_chains:
|
|
385
|
+
chain_seq = b[f'seq_chain_{letter}']
|
|
386
|
+
if letter in visible_chains:
|
|
387
|
+
visible_temp_dict[letter] = chain_seq
|
|
388
|
+
elif letter in masked_chains:
|
|
389
|
+
masked_temp_dict[letter] = chain_seq
|
|
390
|
+
# check for duplicate chains (same sequence but different identity)
|
|
391
|
+
for _, vm in masked_temp_dict.items():
|
|
392
|
+
for kv, vv in visible_temp_dict.items():
|
|
393
|
+
if vm == vv:
|
|
394
|
+
if kv not in masked_chains:
|
|
395
|
+
masked_chains.append(kv)
|
|
396
|
+
if kv in visible_chains:
|
|
397
|
+
visible_chains.remove(kv)
|
|
398
|
+
# build protein data structures
|
|
399
|
+
all_chains = masked_chains + visible_chains
|
|
400
|
+
np.random.shuffle(all_chains)
|
|
401
|
+
x_chain_list = []
|
|
402
|
+
chain_mask_list = []
|
|
403
|
+
chain_seq_list = []
|
|
404
|
+
chain_encoding_list = []
|
|
405
|
+
c, l0, l1 = 1, 0, 0
|
|
406
|
+
for letter in all_chains:
|
|
407
|
+
chain_seq = b[f'seq_chain_{letter}']
|
|
408
|
+
chain_length = len(chain_seq)
|
|
409
|
+
chain_coords = b[f'coords_chain_{letter}']
|
|
410
|
+
x_chain = np.stack([
|
|
411
|
+
chain_coords[c] for c in [
|
|
412
|
+
f'N_chain_{letter}', f'CA_chain_{letter}',
|
|
413
|
+
f'C_chain_{letter}', f'O_chain_{letter}'
|
|
414
|
+
]
|
|
415
|
+
], 1) # [chain_length, 4, 3]
|
|
416
|
+
x_chain_list.append(x_chain)
|
|
417
|
+
chain_seq_list.append(chain_seq)
|
|
418
|
+
if letter in visible_chains:
|
|
419
|
+
chain_mask = np.zeros(chain_length) # 0 for visible chains
|
|
420
|
+
elif letter in masked_chains:
|
|
421
|
+
chain_mask = np.ones(chain_length) # 1 for masked chains
|
|
422
|
+
chain_mask_list.append(chain_mask)
|
|
423
|
+
chain_encoding_list.append(c * np.ones(chain_length))
|
|
424
|
+
l1 += chain_length
|
|
425
|
+
residue_idx[l0:l1] = 100 * (c - 1) + np.arange(l0, l1)
|
|
426
|
+
l0 += chain_length
|
|
427
|
+
c += 1
|
|
428
|
+
x_chain_all = np.concatenate(x_chain_list, 0) # [L, 4, 3]
|
|
429
|
+
chain_seq_all = "".join(chain_seq_list)
|
|
430
|
+
# [L,] 1.0 for places that need to be predicted
|
|
431
|
+
chain_mask_all = np.concatenate(chain_mask_list, 0)
|
|
432
|
+
chain_encoding_all = np.concatenate(chain_encoding_list, 0)
|
|
433
|
+
|
|
434
|
+
# Convert to labels
|
|
435
|
+
alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
|
|
436
|
+
chain_seq_label_all = np.asarray(
|
|
437
|
+
[alphabet.index(a) for a in chain_seq_all], dtype=np.int32)
|
|
438
|
+
|
|
439
|
+
isnan = np.isnan(x_chain_all)
|
|
440
|
+
mask = np.isfinite(np.sum(x_chain_all, (1, 2))).astype(np.float32)
|
|
441
|
+
x_chain_all[isnan] = 0.
|
|
442
|
+
|
|
443
|
+
# Conversion
|
|
444
|
+
return (
|
|
445
|
+
torch.from_numpy(x_chain_all).to(dtype=torch.float32),
|
|
446
|
+
torch.from_numpy(chain_seq_label_all).to(dtype=torch.long),
|
|
447
|
+
torch.from_numpy(mask).to(dtype=torch.float32),
|
|
448
|
+
torch.from_numpy(chain_mask_all).to(dtype=torch.float32),
|
|
449
|
+
torch.from_numpy(residue_idx).to(dtype=torch.long),
|
|
450
|
+
torch.from_numpy(chain_encoding_all).to(dtype=torch.long),
|
|
451
|
+
)
|
torch_geometric/datasets/qm9.py
CHANGED
|
@@ -202,7 +202,7 @@ class QM9(InMemoryDataset):
|
|
|
202
202
|
from rdkit import Chem, RDLogger
|
|
203
203
|
from rdkit.Chem.rdchem import BondType as BT
|
|
204
204
|
from rdkit.Chem.rdchem import HybridizationType
|
|
205
|
-
RDLogger.DisableLog('rdApp.*')
|
|
205
|
+
RDLogger.DisableLog('rdApp.*')
|
|
206
206
|
WITH_RDKIT = True
|
|
207
207
|
|
|
208
208
|
except ImportError:
|