pyg-nightly 2.7.0.dev20250701__py3-none-any.whl → 2.7.0.dev20250703__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic. Click here for more details.
- {pyg_nightly-2.7.0.dev20250701.dist-info → pyg_nightly-2.7.0.dev20250703.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250701.dist-info → pyg_nightly-2.7.0.dev20250703.dist-info}/RECORD +25 -22
- torch_geometric/__init__.py +1 -1
- torch_geometric/datasets/__init__.py +4 -0
- torch_geometric/datasets/git_mol_dataset.py +1 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/conv/meshcnn_conv.py +9 -15
- torch_geometric/nn/encoding.py +12 -3
- torch_geometric/nn/models/__init__.py +2 -0
- torch_geometric/nn/models/glem.py +7 -3
- torch_geometric/nn/models/protein_mpnn.py +304 -0
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +12 -4
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/utils/convert.py +15 -8
- torch_geometric/utils/smiles.py +1 -1
- {pyg_nightly-2.7.0.dev20250701.dist-info → pyg_nightly-2.7.0.dev20250703.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250701.dist-info → pyg_nightly-2.7.0.dev20250703.dist-info}/licenses/LICENSE +0 -0
{pyg_nightly-2.7.0.dev20250701.dist-info → pyg_nightly-2.7.0.dev20250703.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pyg-nightly
|
|
3
|
-
Version: 2.7.0.dev20250701
|
|
3
|
+
Version: 2.7.0.dev20250703
|
|
4
4
|
Summary: Graph Neural Network Library for PyTorch
|
|
5
5
|
Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
|
|
6
6
|
Author-email: Matthias Fey <matthias@pyg.org>
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
torch_geometric/__init__.py,sha256=
|
|
1
|
+
torch_geometric/__init__.py,sha256=8_AcpgPpDfVr7gwK3sYPD-pu9JYUmziOQzs9SPRvqcE,2250
|
|
2
2
|
torch_geometric/_compile.py,sha256=9yqMTBKatZPr40WavJz9FjNi7pQj8YZAZOyZmmRGXgc,1351
|
|
3
3
|
torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
|
|
4
4
|
torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
|
|
@@ -55,7 +55,7 @@ torch_geometric/data/temporal.py,sha256=WOJ6gFrTLikaLhUvotyUF5ql14FkE5Ox3hNkdSp6
|
|
|
55
55
|
torch_geometric/data/view.py,sha256=XjkVSc-UWZFCT4DlXLShZtO8duhFQkS9gq88zZXANsk,1089
|
|
56
56
|
torch_geometric/data/lightning/__init__.py,sha256=w3En1tJfy3kSqe1MycpOyZpHFO3fxBCgNCUOznPA3YU,178
|
|
57
57
|
torch_geometric/data/lightning/datamodule.py,sha256=IjucsIKRcNv16DIqILQnqa_sz72a4-yivoySmEllv2o,29353
|
|
58
|
-
torch_geometric/datasets/__init__.py,sha256=
|
|
58
|
+
torch_geometric/datasets/__init__.py,sha256=rgfUmjd9U3o8renKVl81Brscx4LOtwWmt6qAoaG41C4,6417
|
|
59
59
|
torch_geometric/datasets/actor.py,sha256=oUxgJIX8bi5hJr1etWNYIFyVQNDDXi1nyVpHGGMEAGQ,4304
|
|
60
60
|
torch_geometric/datasets/airfrans.py,sha256=8cCBmHPttrlKY_iwfyr-K-CUX_JEDjrIOg3r9dQSN7o,5439
|
|
61
61
|
torch_geometric/datasets/airports.py,sha256=b3gkv3gY2JkUpmGiz36Z-g7EcnSfU8lBG1YsCOWdJ6k,3758
|
|
@@ -95,7 +95,7 @@ torch_geometric/datasets/gdelt_lite.py,sha256=zE1WagpgmsQARQhEgdCBtALRKyuQvIZqxT
|
|
|
95
95
|
torch_geometric/datasets/ged_dataset.py,sha256=dtd-C6pCygNHLXgVfg3ZTWtTVHKT13Q3GlGrze1_rpo,9551
|
|
96
96
|
torch_geometric/datasets/gemsec.py,sha256=oMTSryTgyed9z_4ydg3ql12KM-_35uqL1AoNls5nG8M,2820
|
|
97
97
|
torch_geometric/datasets/geometry.py,sha256=-BxUMirZcUOf01c3avvF0b6wGPn-4S3Zj3Oau1RaJVk,4223
|
|
98
|
-
torch_geometric/datasets/git_mol_dataset.py,sha256=
|
|
98
|
+
torch_geometric/datasets/git_mol_dataset.py,sha256=l5u4U86tfjJdHtQPN7SM3Yjv25LD1Idtm7VHaqJqNik,10665
|
|
99
99
|
torch_geometric/datasets/github.py,sha256=Qhqhkvi6eZ8VF_HqP1rL2iYToZavFNsQh7J1WdeM9dA,2687
|
|
100
100
|
torch_geometric/datasets/gnn_benchmark_dataset.py,sha256=4P8n7czF-gf1egLYlAcSSvfB0GXIKpAbH5UjsuFld1M,6976
|
|
101
101
|
torch_geometric/datasets/heterophilous_graph_dataset.py,sha256=yHHtwl4uPrid0vPOxvPV3sIS8HWdswar8FJ0h0OQ9is,4224
|
|
@@ -139,8 +139,9 @@ torch_geometric/datasets/pcqm4m.py,sha256=7ID_xXXIAyuNzYLI2lBWygZl9wGos-dbaz1b6E
|
|
|
139
139
|
torch_geometric/datasets/planetoid.py,sha256=RksfwR_PI7qGVphs-T-4jXDepYwQCweMXElLm096hgg,7201
|
|
140
140
|
torch_geometric/datasets/polblogs.py,sha256=IYzsvd4R0OojmOOZUoOdCwQYfwcTfth1PNtcBK1yOGc,3045
|
|
141
141
|
torch_geometric/datasets/ppi.py,sha256=zPtg-omC7WYvr9Tzwkb7zNjpXLODsvxKxKdGEUswp2E,5030
|
|
142
|
+
torch_geometric/datasets/protein_mpnn_dataset.py,sha256=TTeTVJMo0Rlt2_h9bbZMKJe3rTJcjCgY5cXGyWteBfA,17756
|
|
142
143
|
torch_geometric/datasets/qm7.py,sha256=bYyK8xlh9kTr5vqueNbLu9EAjIXkQH1KX1VWnjKfOJc,3323
|
|
143
|
-
torch_geometric/datasets/qm9.py,sha256=
|
|
144
|
+
torch_geometric/datasets/qm9.py,sha256=Ub1t8KNeWFZvw50_Qk-80yNFeYFDwdAeyQtp3JHZs7o,17197
|
|
144
145
|
torch_geometric/datasets/rcdd.py,sha256=gvOoM1tw_X5QMyBB4FkMUwNErMXAvImyjz5twktBAh8,5317
|
|
145
146
|
torch_geometric/datasets/reddit.py,sha256=QUgiKTaj6YTOYbgWgqV8mPYsctOui2ujaM8f8qy81v0,3131
|
|
146
147
|
torch_geometric/datasets/reddit2.py,sha256=WSdrhbDPcUEG37XWNUd0uKnqgI911MOcfjXmgjbTPoQ,4291
|
|
@@ -153,6 +154,7 @@ torch_geometric/datasets/snap_dataset.py,sha256=deJvB6cpIQ3bu_pcWoqgEo1-Kl_NcFi7
|
|
|
153
154
|
torch_geometric/datasets/suite_sparse.py,sha256=eqjH4vAUq872qdk3YdLkZSwlu6r7HHpTgK0vEVGmY1s,3278
|
|
154
155
|
torch_geometric/datasets/tag_dataset.py,sha256=qTnwr2N1tbWYeLGbItfv70UxQ3n1rKesjeVU3kcOCP8,14757
|
|
155
156
|
torch_geometric/datasets/taobao.py,sha256=CUcZpbWsNTasevflO8zqP0YvENy89P7wpKS4MHaDJ6Q,4170
|
|
157
|
+
torch_geometric/datasets/teeth3ds.py,sha256=hZvhcq9lsQENNFr5hk50w2T3CgxE_tlnQfrCgN6uIDQ,9919
|
|
156
158
|
torch_geometric/datasets/tosca.py,sha256=nUSF8NQT1GlkwWQLshjWmr8xORsvRHzzIqhUyDCvABc,4632
|
|
157
159
|
torch_geometric/datasets/tu_dataset.py,sha256=14OSaXBgVwT1dX2h1wZ3xVIwoo0GQBEfR3yWh6Q0VF0,7847
|
|
158
160
|
torch_geometric/datasets/twitch.py,sha256=qfEerf-Uaojx2ZvegENowdG4E7RoUT_HUO9xtULadvo,3658
|
|
@@ -295,7 +297,7 @@ torch_geometric/metrics/__init__.py,sha256=3krvDobW6vV5yHTjq2S2pmOXxNfysNG26muq7
|
|
|
295
297
|
torch_geometric/metrics/link_pred.py,sha256=dtaI39JB-WqE1B-raiElns6xySRwmkbb9izbcyt6xHI,30886
|
|
296
298
|
torch_geometric/nn/__init__.py,sha256=kQHHHUxFDht2ztD-XFQuv98TvC8MdodaFsIjAvltJBw,874
|
|
297
299
|
torch_geometric/nn/data_parallel.py,sha256=YiybTWoSFyfSzlXAamZ_-y1f7B6tvDEFHOuy_AyJz9Q,4761
|
|
298
|
-
torch_geometric/nn/encoding.py,sha256=
|
|
300
|
+
torch_geometric/nn/encoding.py,sha256=3DCOCO-XFt-lMb97sHWGN-4KeGUFY5lVo9P00SzrCNk,3559
|
|
299
301
|
torch_geometric/nn/fx.py,sha256=PDtaHJAgodh4xf8FNl4fVxPGZJDbRaq3Q9z8qb1DNNI,16063
|
|
300
302
|
torch_geometric/nn/glob.py,sha256=MdHjcUlHmFmTevzwND1_x7dXXJPzIDTBJRGOrGdZ8dQ,1088
|
|
301
303
|
torch_geometric/nn/inits.py,sha256=_8FqacCLPz5Ft2zB5s6dtKGTKWtfrLyCLLuv1QvyKjk,2457
|
|
@@ -325,13 +327,13 @@ torch_geometric/nn/aggr/lcm.py,sha256=TcNqEvHnWpqOc9RFFioBAssQaUhOgMpH1_ovOmgl3w
|
|
|
325
327
|
torch_geometric/nn/aggr/lstm.py,sha256=AdLa4rDd8t_X-GADDTOzRFuifSA0tIYVGKfoOckVtUE,2214
|
|
326
328
|
torch_geometric/nn/aggr/mlp.py,sha256=sHQ4vQcZ-h2aOfFIBiXpAjr2lj7zHT3_TyqQr3WUjxQ,2514
|
|
327
329
|
torch_geometric/nn/aggr/multi.py,sha256=theSIaDlLjGUyAtqDvOFORRpI9gYoZMXUtypX1PV5NQ,8170
|
|
328
|
-
torch_geometric/nn/aggr/patch_transformer.py,sha256=
|
|
330
|
+
torch_geometric/nn/aggr/patch_transformer.py,sha256=tWWBqBIuIPJfvFhkEs-S8cdEhuU1qxHsxoLh_ZnHznw,5498
|
|
329
331
|
torch_geometric/nn/aggr/quantile.py,sha256=sRnKyt4CXr9RmjoPyTl4VUvXgSCMl9PG-fhCGsSZ76c,6189
|
|
330
332
|
torch_geometric/nn/aggr/scaler.py,sha256=GV6gxUFBoKYMQTGybwzoPh708OY6k6chtUYmCIbFGXk,4638
|
|
331
333
|
torch_geometric/nn/aggr/set2set.py,sha256=4GdmsjbBIrap3CG2naeFNsYe5eE-fhrNQOXM1-TIxyM,2446
|
|
332
334
|
torch_geometric/nn/aggr/set_transformer.py,sha256=FG7_JizpFX14M6VSCwLSjYXYdJ1ZiQVbvnaYHIraiuM,4213
|
|
333
335
|
torch_geometric/nn/aggr/sort.py,sha256=bvOOWnFkNOBOZih4rqVZQsjfeDX3vmXo1bpPSFD846w,2507
|
|
334
|
-
torch_geometric/nn/aggr/utils.py,sha256=
|
|
336
|
+
torch_geometric/nn/aggr/utils.py,sha256=SQvdc0g6p_E2j0prA14MW2ekjEDvV-g545N0Q85uc-o,8625
|
|
335
337
|
torch_geometric/nn/aggr/variance_preserving.py,sha256=fu-U_aGYpVLpgSFvVg0ONMe6nqoyv8tZ6Y35qMYTf9w,1126
|
|
336
338
|
torch_geometric/nn/attention/__init__.py,sha256=wLKTmlfP7qL9sZHy4cmDFHEtdwa-MEKE1dT51L1_w10,192
|
|
337
339
|
torch_geometric/nn/attention/performer.py,sha256=2PCDn4_-oNTao2-DkXIaoi18anP01OxRELF2pvp-jk8,7357
|
|
@@ -374,7 +376,7 @@ torch_geometric/nn/conv/hgt_conv.py,sha256=lUhTWUMovMtn9yR_b2-kLNLqHChGOUl2OtXBY
|
|
|
374
376
|
torch_geometric/nn/conv/hypergraph_conv.py,sha256=4BosbbqJyprlI6QjPqIfMxCqnARU_0mUn1zcAQhbw90,8691
|
|
375
377
|
torch_geometric/nn/conv/le_conv.py,sha256=DonmmYZOKk5wIlTZzzIfNKqBY6MO0MRxYhyr0YtNz-Q,3494
|
|
376
378
|
torch_geometric/nn/conv/lg_conv.py,sha256=8jMa79iPsOUbXEfBIc3wmbvAD8T3d1j37LeIFTX3Yag,2369
|
|
377
|
-
torch_geometric/nn/conv/meshcnn_conv.py,sha256=
|
|
379
|
+
torch_geometric/nn/conv/meshcnn_conv.py,sha256=92zUcgfS0Fwv-MpddF4Ia1a65y7ddPAkazYf7D6kvwg,21951
|
|
378
380
|
torch_geometric/nn/conv/message_passing.py,sha256=ZuTvSvodGy1GyAW4mHtuoMUuxclam-7opidYNY5IHm8,44377
|
|
379
381
|
torch_geometric/nn/conv/mf_conv.py,sha256=SkOGMN1tFT9dcqy8xYowsB2ozw6QfkoArgR1BksZZaU,4340
|
|
380
382
|
torch_geometric/nn/conv/mixhop_conv.py,sha256=qVDPWeWcnO7_eHM0ZnpKtr8SISjb4jp0xjgpoDrwjlk,4555
|
|
@@ -429,7 +431,7 @@ torch_geometric/nn/kge/distmult.py,sha256=dGQ0bVzjreZgFN1lXE23_IIidsiOq7ehPrMb-N
|
|
|
429
431
|
torch_geometric/nn/kge/loader.py,sha256=5Uc1j3OUMQnBYSHDqL7pLCty1siFLzoPkztigYO2zP8,771
|
|
430
432
|
torch_geometric/nn/kge/rotate.py,sha256=XLuO1AbyTt5cJxr97ZzoyAyIEsHKesgW5TvDmnGJAao,3208
|
|
431
433
|
torch_geometric/nn/kge/transe.py,sha256=jlejq5BLMm-sb1wWcLDp7pZqCdelWBgjDIC8ctbjSdU,3088
|
|
432
|
-
torch_geometric/nn/models/__init__.py,sha256=
|
|
434
|
+
torch_geometric/nn/models/__init__.py,sha256=fbHQauZw9Snvl2PuN5cjZoAW8SwUl6E-p2IOmwUKB3A,2395
|
|
433
435
|
torch_geometric/nn/models/attentive_fp.py,sha256=1z3iTV2O5W9tqHFAdno8FeBFeXmuG-TDZk4lwwVh3Ac,6634
|
|
434
436
|
torch_geometric/nn/models/attract_repel.py,sha256=h9OyogT0NY0xiT0DkpJHMxH6ZUmo8R-CmwZdKEwq8Ek,5277
|
|
435
437
|
torch_geometric/nn/models/autoencoder.py,sha256=nGje-zty78Y3hxOJ9o0_6QziJjOvBlknk6z0_fDQwQU,10770
|
|
@@ -442,7 +444,7 @@ torch_geometric/nn/models/dimenet.py,sha256=O2rqEx5HWs_lMwRD8eq6WMkbqJaCLL5zgWUJ
|
|
|
442
444
|
torch_geometric/nn/models/dimenet_utils.py,sha256=Eyn_EiJqwKvuYj6BtRpSxrzMG3v4Gk98X9MxZ7uvwm4,5069
|
|
443
445
|
torch_geometric/nn/models/g_retriever.py,sha256=tVibbqM_r-1LnA3R3oVyzp0bpuN3qPoYqcU6LZ8dYEk,8260
|
|
444
446
|
torch_geometric/nn/models/git_mol.py,sha256=Wc6Hx6RDDR7sDWRWHfA5eK9e9gFsrTZ9OLmpMfoj3pE,12676
|
|
445
|
-
torch_geometric/nn/models/glem.py,sha256=
|
|
447
|
+
torch_geometric/nn/models/glem.py,sha256=GlL_I63g-_5eTycSGRj720YntldQ-CQ351RaDPc6XAU,16674
|
|
446
448
|
torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
|
|
447
449
|
torch_geometric/nn/models/gpse.py,sha256=acEAeeicLgzKRL54WhvIFxjA5XViHgXgMEH-NgbMdqI,41971
|
|
448
450
|
torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
|
|
@@ -459,6 +461,7 @@ torch_geometric/nn/models/molecule_gpt.py,sha256=k-XULH6jaurj-R2EE4sIWTkqlNqa3Cz
|
|
|
459
461
|
torch_geometric/nn/models/neural_fingerprint.py,sha256=pTLJgU9Uh2Lnf9bggLj4cKI8YdEFcMF-9MALuubqbuQ,2378
|
|
460
462
|
torch_geometric/nn/models/node2vec.py,sha256=81Ku4Rp4IwLEAy06KEgJ2fYtXXVL_uv_Hb8lBr6YXrE,7664
|
|
461
463
|
torch_geometric/nn/models/pmlp.py,sha256=dcAASVSyQMMhItSfEJWPeAFh0R3tNCwAHwdrShwQ8o4,3538
|
|
464
|
+
torch_geometric/nn/models/protein_mpnn.py,sha256=QXHfltiJPmakpzgJKw_1vwCGBlszv9nfY4r4F38Sg9k,11031
|
|
462
465
|
torch_geometric/nn/models/re_net.py,sha256=pz66q5b5BoGDNVQvpEGS2RGoeKvpjkYAv9r3WAuvITk,8986
|
|
463
466
|
torch_geometric/nn/models/rect.py,sha256=2F3XyyvHTAEuqfJpiNB5M8pSGy738LhPiom5I-SDWqM,2808
|
|
464
467
|
torch_geometric/nn/models/rev_gnn.py,sha256=Bpme087Zs227lcB0ODOKWsxaly67q96wseaRt6bacjs,11796
|
|
@@ -472,14 +475,14 @@ torch_geometric/nn/nlp/llm.py,sha256=DAv9jOZKXKQNVU2pNMyS1q8gVUtlin_unc6FjLhOYto
|
|
|
472
475
|
torch_geometric/nn/nlp/sentence_transformer.py,sha256=q5M7SGtrUzoSiNhKCGFb7JatWiukdhNF6zdq2yiqxwE,4475
|
|
473
476
|
torch_geometric/nn/nlp/vision_transformer.py,sha256=diVBefjIynzYs8WBlcpTeSVnw1PUecHY--B9Yd-W2hA,863
|
|
474
477
|
torch_geometric/nn/norm/__init__.py,sha256=u2qIDrkbeuObGVXSAIftAlvSd6ouGTtxznCfD-59UiA,669
|
|
475
|
-
torch_geometric/nn/norm/batch_norm.py,sha256=
|
|
476
|
-
torch_geometric/nn/norm/diff_group_norm.py,sha256=
|
|
477
|
-
torch_geometric/nn/norm/graph_norm.py,sha256=
|
|
478
|
+
torch_geometric/nn/norm/batch_norm.py,sha256=fzUNmpdCUsMnNcso3PKDUdWc0UQvziK80-w0ZC0Vb8U,8706
|
|
479
|
+
torch_geometric/nn/norm/diff_group_norm.py,sha256=mT0gM5a8txcAFNwZGKFu12qnNF5Pn95zrMx-RisRsh4,4938
|
|
480
|
+
torch_geometric/nn/norm/graph_norm.py,sha256=VRmpi2jNYRQWXzX6Z0FmBxdEiV-EYXNPbGAGC0XNKH8,2964
|
|
478
481
|
torch_geometric/nn/norm/graph_size_norm.py,sha256=sh5Nue1Ix2jC1T7o7KqOw0_TAOcpZ4VbYzhADWE97-M,1491
|
|
479
|
-
torch_geometric/nn/norm/instance_norm.py,sha256=
|
|
480
|
-
torch_geometric/nn/norm/layer_norm.py,sha256=
|
|
482
|
+
torch_geometric/nn/norm/instance_norm.py,sha256=L8VquSF7Jh5xfxA4YEcSO3IZ4fWR7VIiSukNeRXi4z0,4870
|
|
483
|
+
torch_geometric/nn/norm/layer_norm.py,sha256=XiEyoXdDta6vlInLfwbJVsEgTkBFG6PJm6SpK99-cPE,8243
|
|
481
484
|
torch_geometric/nn/norm/mean_subtraction_norm.py,sha256=KVHOp413mw7obwAN09Le6XdgobtCXpi4UKpjpG1M550,1322
|
|
482
|
-
torch_geometric/nn/norm/msg_norm.py,sha256=
|
|
485
|
+
torch_geometric/nn/norm/msg_norm.py,sha256=NiV51ce1JgxVY6GbzktoSslDnZKWJrMJYZc_eBxz-pg,1903
|
|
483
486
|
torch_geometric/nn/norm/pair_norm.py,sha256=IfHMiVYw_xsy035NakbPGdQVaVC-Ue3Oxwo651Vc47I,2824
|
|
484
487
|
torch_geometric/nn/pool/__init__.py,sha256=VU9cPdLC-MPgt1kfS0ZwehfSD3g0V30VQuR1Wo0mzZE,14250
|
|
485
488
|
torch_geometric/nn/pool/approx_knn.py,sha256=n7C8Cbar6o5tJcuAbzhM5hqMK26hW8dm5DopuocidO0,3967
|
|
@@ -613,7 +616,7 @@ torch_geometric/utils/_tree_decomposition.py,sha256=ZtpjPQJgXbQWtSWjo-Fmhrov0DGO
|
|
|
613
616
|
torch_geometric/utils/_trim_to_layer.py,sha256=cauOEzMJJK4w9BC-Pg1bHVncBYqG9XxQex3rn10BFjc,8339
|
|
614
617
|
torch_geometric/utils/_unbatch.py,sha256=B0vjKI96PtHvSBG8F_lqvsiJE134aVjUurPZsG6UZRI,2378
|
|
615
618
|
torch_geometric/utils/augmentation.py,sha256=1F0YCuaklZ9ZbXxdFV0oOoemWvLd8p60WvFo2chzl7E,8600
|
|
616
|
-
torch_geometric/utils/convert.py,sha256=
|
|
619
|
+
torch_geometric/utils/convert.py,sha256=RE5n5no74Xu39-QMWFE0-1RvTgykdK33ymyjF9WcuSs,21938
|
|
617
620
|
torch_geometric/utils/cross_entropy.py,sha256=ZFS5bivtzv3EV9zqgKsekmuQyoZZggPSclhl_tRNHxo,3047
|
|
618
621
|
torch_geometric/utils/dropout.py,sha256=gg0rDnD4FLvBaKSoLAkZwViAQflhLefJm6_Mju5dmQs,11416
|
|
619
622
|
torch_geometric/utils/embedding.py,sha256=Ac_MPSrZGpw-e-gU6Yz-seUioC2WZxBSSzXFeclGwMk,5232
|
|
@@ -634,13 +637,13 @@ torch_geometric/utils/num_nodes.py,sha256=F15ciTFOe8AxjkUh1wKH7RLmJvQYYpz-l3pPPv
|
|
|
634
637
|
torch_geometric/utils/ppr.py,sha256=ebiHbQqRJsQbGUI5xu-IkzQSQsgIaC71vgO0KcXIKAk,4055
|
|
635
638
|
torch_geometric/utils/random.py,sha256=Rv5HlhG5310rytbT9EZ7xWLGKQfozfz1azvYi5nx2-U,5148
|
|
636
639
|
torch_geometric/utils/repeat.py,sha256=RxCoRoEisaP6NouXPPW5tY1Rn-tIfrmpJPm0qGP6W8M,815
|
|
637
|
-
torch_geometric/utils/smiles.py,sha256=
|
|
640
|
+
torch_geometric/utils/smiles.py,sha256=CFqeNtSBXQtY9Ex2gQzI0La490IpVVrm01QdRYEpV7w,7114
|
|
638
641
|
torch_geometric/utils/sparse.py,sha256=1DbaEwdyvnzvg5qVjPlnWcEVDMkxrQLX1jJ0dr6P4js,25135
|
|
639
642
|
torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5nUAUjw,6222
|
|
640
643
|
torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
|
|
641
644
|
torch_geometric/visualization/graph.py,sha256=mfZHXYfiU-CWMtfawYc80IxVwVmtK9hbIkSKhM_j7oI,14311
|
|
642
645
|
torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
|
|
643
|
-
pyg_nightly-2.7.0.
|
|
644
|
-
pyg_nightly-2.7.0.
|
|
645
|
-
pyg_nightly-2.7.0.
|
|
646
|
-
pyg_nightly-2.7.0.
|
|
646
|
+
pyg_nightly-2.7.0.dev20250703.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
|
|
647
|
+
pyg_nightly-2.7.0.dev20250703.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
|
|
648
|
+
pyg_nightly-2.7.0.dev20250703.dist-info/METADATA,sha256=lm-xVNswGkfSaczkhV_ANegM0oGwBEfhCv75gwj0X5Q,63005
|
|
649
|
+
pyg_nightly-2.7.0.dev20250703.dist-info/RECORD,,
|
torch_geometric/__init__.py
CHANGED
|
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
|
|
|
31
31
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
|
32
32
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
|
33
33
|
|
|
34
|
-
__version__ = '2.7.0.
|
|
34
|
+
__version__ = '2.7.0.dev20250703'
|
|
35
35
|
|
|
36
36
|
__all__ = [
|
|
37
37
|
'Index',
|
|
@@ -81,8 +81,10 @@ from .web_qsp_dataset import WebQSPDataset, CWQDataset
|
|
|
81
81
|
from .git_mol_dataset import GitMolDataset
|
|
82
82
|
from .molecule_gpt_dataset import MoleculeGPTDataset
|
|
83
83
|
from .instruct_mol_dataset import InstructMolDataset
|
|
84
|
+
from .protein_mpnn_dataset import ProteinMPNNDataset
|
|
84
85
|
from .tag_dataset import TAGDataset
|
|
85
86
|
from .city import CityNetwork
|
|
87
|
+
from .teeth3ds import Teeth3DS
|
|
86
88
|
|
|
87
89
|
from .dbp15k import DBP15K
|
|
88
90
|
from .aminer import AMiner
|
|
@@ -201,8 +203,10 @@ homo_datasets = [
|
|
|
201
203
|
'GitMolDataset',
|
|
202
204
|
'MoleculeGPTDataset',
|
|
203
205
|
'InstructMolDataset',
|
|
206
|
+
'ProteinMPNNDataset',
|
|
204
207
|
'TAGDataset',
|
|
205
208
|
'CityNetwork',
|
|
209
|
+
'Teeth3DS',
|
|
206
210
|
]
|
|
207
211
|
|
|
208
212
|
hetero_datasets = [
|
|
@@ -0,0 +1,451 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pickle
|
|
3
|
+
import random
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
|
|
11
|
+
from torch_geometric.data import (
|
|
12
|
+
Data,
|
|
13
|
+
InMemoryDataset,
|
|
14
|
+
download_url,
|
|
15
|
+
extract_tar,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ProteinMPNNDataset(InMemoryDataset):
|
|
20
|
+
r"""The ProteinMPNN dataset from the `"Robust deep learning based protein
|
|
21
|
+
sequence design using ProteinMPNN"
|
|
22
|
+
<https://www.biorxiv.org/content/10.1101/2022.06.03.494563v1>`_ paper.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
root (str): Root directory where the dataset should be saved.
|
|
26
|
+
size (str): Size of the PDB information to train the model.
|
|
27
|
+
If :obj:`"small"`, loads the small dataset (229.4 MB).
|
|
28
|
+
If :obj:`"large"`, loads the large dataset (64.1 GB).
|
|
29
|
+
(default: :obj:`"small"`)
|
|
30
|
+
split (str, optional): If :obj:`"train"`, loads the training dataset.
|
|
31
|
+
If :obj:`"valid"`, loads the validation dataset.
|
|
32
|
+
If :obj:`"test"`, loads the test dataset.
|
|
33
|
+
(default: :obj:`"train"`)
|
|
34
|
+
datacut (str, optional): Date cutoff to filter the dataset.
|
|
35
|
+
(default: :obj:`"2030-01-01"`)
|
|
36
|
+
rescut (float, optional): PDB resolution cutoff.
|
|
37
|
+
(default: :obj:`3.5`)
|
|
38
|
+
homo (float, optional): Homology cutoff.
|
|
39
|
+
(default: :obj:`0.70`)
|
|
40
|
+
max_length (int, optional): Maximum length of the protein complex.
|
|
41
|
+
(default: :obj:`10000`)
|
|
42
|
+
num_units (int, optional): Number of units of the protein complex.
|
|
43
|
+
(default: :obj:`150`)
|
|
44
|
+
transform (callable, optional): A function/transform that takes in an
|
|
45
|
+
:obj:`torch_geometric.data.Data` object and returns a transformed
|
|
46
|
+
version. The data object will be transformed before every access.
|
|
47
|
+
(default: :obj:`None`)
|
|
48
|
+
pre_transform (callable, optional): A function/transform that takes in
|
|
49
|
+
an :obj:`torch_geometric.data.Data` object and returns a
|
|
50
|
+
transformed version. The data object will be transformed before
|
|
51
|
+
being saved to disk. (default: :obj:`None`)
|
|
52
|
+
pre_filter (callable, optional): A function that takes in an
|
|
53
|
+
:obj:`torch_geometric.data.Data` object and returns a boolean
|
|
54
|
+
value, indicating whether the data object should be included in the
|
|
55
|
+
final dataset. (default: :obj:`None`)
|
|
56
|
+
force_reload (bool, optional): Whether to re-process the dataset.
|
|
57
|
+
(default: :obj:`False`)
|
|
58
|
+
"""
|
|
59
|
+
|
|
60
|
+
raw_url = {
|
|
61
|
+
'small':
|
|
62
|
+
'https://files.ipd.uw.edu/pub/training_sets/'
|
|
63
|
+
'pdb_2021aug02_sample.tar.gz',
|
|
64
|
+
'large':
|
|
65
|
+
'https://files.ipd.uw.edu/pub/training_sets/'
|
|
66
|
+
'pdb_2021aug02.tar.gz',
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
splits = {
|
|
70
|
+
'train': 1,
|
|
71
|
+
'valid': 2,
|
|
72
|
+
'test': 3,
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
def __init__(
    self,
    root: str,
    size: str = 'small',
    split: str = 'train',
    datacut: str = '2030-01-01',
    rescut: float = 3.5,
    homo: float = 0.70,
    max_length: int = 10000,
    num_units: int = 150,
    transform: Optional[Callable] = None,
    pre_transform: Optional[Callable] = None,
    pre_filter: Optional[Callable] = None,
    force_reload: bool = False,
) -> None:
    # Validate the selector arguments before any work happens: an unknown
    # `size` would otherwise surface as a bare KeyError from `raw_url`, and
    # an unknown `split` would only fail inside `process()`, after the raw
    # split table has already been read and cached.
    if size not in self.raw_url:
        raise ValueError(f"Invalid 'size' argument (got '{size}', "
                         f"expected one of {sorted(self.raw_url)})")
    if split not in self.splits:
        raise ValueError(f"Invalid 'split' argument (got '{split}', "
                         f"expected one of {list(self.splits)})")

    self.size = size
    self.split = split
    self.datacut = datacut
    self.rescut = rescut
    self.homo = homo
    self.max_length = max_length
    self.num_units = num_units

    # Directory name the tarball extracts to: the archive file name without
    # its extensions, e.g. 'pdb_2021aug02_sample' for the small dataset.
    self.sub_folder = self.raw_url[self.size].split('/')[-1].split('.')[0]

    super().__init__(root, transform, pre_transform, pre_filter,
                     force_reload=force_reload)
    # processed_paths[0] is the cached split; indices 1..3 hold the
    # train/valid/test data (see `splits`).
    self.load(self.processed_paths[self.splits[self.split]])
|
|
103
|
+
|
|
104
|
+
@property
def raw_file_names(self) -> List[str]:
    r"""Raw files expected inside :obj:`sub_folder` after extraction.

    The order is relied upon by :meth:`_process_split`, which reads them as
    ``raw_paths[0]`` (chain list), ``raw_paths[1]`` (validation clusters)
    and ``raw_paths[2]`` (test clusters).
    """
    folder = self.sub_folder
    required = ('list.csv', 'valid_clusters.txt', 'test_clusters.txt')
    return [f'{folder}/{name}' for name in required]
|
|
110
|
+
|
|
111
|
+
@property
def processed_file_names(self) -> List[str]:
    r"""Processed artefacts: the cached split mapping first, then one file
    per split.

    The order matters: :meth:`_process_split` caches to index 0 and the
    ``splits`` table maps ``'train'``/``'valid'``/``'test'`` to indices
    1/2/3 of ``processed_paths``.
    """
    per_split = [f'{name}.pt' for name in ('train', 'valid', 'test')]
    return ['splits.pkl', *per_split]
|
|
114
|
+
|
|
115
|
+
def download(self) -> None:
    r"""Fetch the archive selected by :obj:`size` into :obj:`raw_dir`,
    unpack it there and delete the archive afterwards."""
    url = self.raw_url[self.size]
    archive_path = download_url(url, self.raw_dir)
    extract_tar(archive_path, self.raw_dir)
    # Only the extracted tree is needed from here on.
    os.remove(archive_path)
|
|
119
|
+
|
|
120
|
+
def process(self) -> None:
    r"""Build the :class:`~torch_geometric.data.Data` objects of the
    current split and save them to ``processed_paths[splits[split]]``.

    Chains are visited cluster by cluster in the order returned by
    :meth:`_process_split`. A chain is skipped when :meth:`_process_pdb1`
    returns no ``'label'`` entry, when its assembly has too many chains to
    be labelled by :meth:`_process_pdb2`, when the concatenated sequence is
    longer than :obj:`max_length`, when it contains a residue letter
    outside the accepted alphabet, or when :obj:`pre_filter` rejects it.
    Collection stops as soon as :obj:`num_units` objects have been gathered.
    """
    # Accepted residue letters: the 20 standard one-letter amino-acid codes
    # plus 'X' (presumably unknown / non-standard residues).
    alphabet_set = set('ACDEFGHIKLMNPQRSTVWYX')
    # `_process_pdb2` labels chains with 52 letters + 300 numeric strings,
    # so assemblies with 352 or more chains cannot be encoded.
    max_num_chains = 352

    cluster_ids = self._process_split()
    total_items = sum(len(items) for items in cluster_ids.values())
    chain_ids = (chain_id for items in cluster_ids.values()
                 for chain_id, _ in items)
    data_list: List[Data] = []

    with tqdm(total=total_items, desc="Processing") as pbar:
        for chain_id in chain_ids:
            # Advance once per visited chain, whether it is kept or skipped
            # (this also covers chains rejected by `pre_filter`).
            pbar.update(1)

            item = self._process_pdb1(chain_id)
            if 'label' not in item:
                continue
            if len(np.unique(item['idx'])) >= max_num_chains:
                continue

            my_dict = self._process_pdb2(item)
            if len(my_dict['seq']) > self.max_length:
                continue
            if not set(my_dict['seq']) <= alphabet_set:
                continue

            (x_chain_all, chain_seq_label_all, mask, chain_mask_all,
             residue_idx, chain_encoding_all) = self._process_pdb3(my_dict)

            data = Data(
                x=x_chain_all,  # [seq_len, 4, 3]
                chain_seq_label=chain_seq_label_all,  # [seq_len]
                mask=mask,  # [seq_len]
                chain_mask_all=chain_mask_all,  # [seq_len]
                residue_idx=residue_idx,  # [seq_len]
                chain_encoding_all=chain_encoding_all,  # [seq_len]
            )

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)
            if len(data_list) >= self.num_units:
                # Enough samples: mark the remaining chains as done and stop.
                pbar.update(total_items - pbar.n)
                break

    self.save(data_list, self.processed_paths[self.splits[self.split]])
|
|
177
|
+
|
|
178
|
+
def _process_split(self) -> Dict[int, List[Tuple[str, int]]]:
    r"""Return ``{cluster_id: [(chain_id, hash_id), ...]}`` for the current
    split.

    Chains from ``list.csv`` are kept when ``RESOLUTION <= rescut`` and
    ``DEPOSITION <= datacut``. A chain goes to the validation or test split
    when its cluster is listed in ``valid_clusters.txt`` or
    ``test_clusters.txt`` respectively, and to the training split
    otherwise. The mapping for all three splits is cached in
    ``processed_paths[0]`` and reused on later calls unless
    :obj:`force_reload` is set.
    """
    import pandas as pd
    save_path = self.processed_paths[0]

    # The cache does not record the filter settings it was built with, so
    # rebuild it when a forced reload was requested.
    force_reload = getattr(self, 'force_reload', False)
    if os.path.exists(save_path) and not force_reload:
        print('Load split')
        # The pickle is written by this method below, not downloaded.
        with open(save_path, 'rb') as f:
            data = pickle.load(f)
        return data[self.split]

    # CHAINID, DEPOSITION, RESOLUTION, HASH, CLUSTER, SEQUENCE
    df = pd.read_csv(self.raw_paths[0])
    df = df[(df['RESOLUTION'] <= self.rescut)
            & (df['DEPOSITION'] <= self.datacut)]

    # Sets make the per-row membership tests below O(1) instead of a scan
    # over every listed cluster.
    val_ids = set(pd.read_csv(self.raw_paths[1], header=None)[0].tolist())
    test_ids = set(pd.read_csv(self.raw_paths[2], header=None)[0].tolist())

    data = {
        'train': defaultdict(list),
        'valid': defaultdict(list),
        'test': defaultdict(list),
    }

    # Iterating the needed columns directly avoids building a Series per
    # row as `iterrows` does.
    rows = zip(df['CLUSTER'], df['HASH'], df['CHAINID'])
    for cluster_id, hash_id, chain_id in tqdm(rows, desc='Processing split',
                                              total=len(df)):
        if cluster_id in val_ids:
            target = 'valid'
        elif cluster_id in test_ids:
            target = 'test'
        else:
            target = 'train'
        data[target][cluster_id].append((chain_id, hash_id))

    with open(save_path, 'wb') as f:
        pickle.dump(data, f)

    return data[self.split]
|
|
218
|
+
|
|
219
|
+
def _process_pdb1(self, chain_id: str) -> Dict[str, Any]:
    r"""Assemble the biological assembly that contains :obj:`chain_id`.

    Args:
        chain_id (str): Identifier of the form ``'{pdbid}_{chain}'``.

    Returns:
        A dictionary with the concatenated sequence (``'seq'``), the stacked
        coordinates of all transformed chains (``'xyz'``), a per-residue
        chain counter (``'idx'``), the counters of chains whose sequence
        identity to :obj:`chain_id` exceeds :obj:`homo` (``'masked'``) and
        ``'label'``. If the metadata file is missing or a chain of the
        assembly cannot be resolved, ``{'seq': np.zeros(5)}`` is returned
        instead. It has no ``'label'`` key, which callers use to skip it.
    """
    pdbid, chid = chain_id.split('_')
    prefix = f'{self.raw_dir}/{self.sub_folder}/pdb/{pdbid[1:3]}/{pdbid}'

    # load metadata
    if not os.path.isfile(f'{prefix}.pt'):
        return {'seq': np.zeros(5)}
    # NOTE(review): torch.load unpickles downloaded files; on torch versions
    # where `weights_only` defaults to False this can run arbitrary code.
    meta = torch.load(f'{prefix}.pt')
    asmb_ids = meta['asmb_ids']
    asmb_chains = meta['asmb_chains']
    chids = np.array(meta['chains'])
    known_chains = set(meta['chains'])

    # Candidate assemblies containing `chid`, in first-seen order. (A set's
    # iteration order depends on string-hash randomisation, which made the
    # random pick below irreproducible even with a fixed seed.)
    asmb_candidates = list(dict.fromkeys(
        a for a, b in zip(asmb_ids, asmb_chains) if chid in b.split(',')))

    # If the chain is missing from all assemblies, return it on its own.
    if len(asmb_candidates) < 1:
        chain = torch.load(f'{prefix}_{chid}.pt')
        L = len(chain['seq'])
        return {
            'seq': chain['seq'],
            'xyz': chain['xyz'],
            'idx': torch.zeros(L).int(),
            'masked': torch.Tensor([0]).int(),
            'label': chain_id,
        }

    # randomly pick one assembly from candidates
    asmb_i = random.choice(asmb_candidates)

    # indices of the transforms belonging to the selected assembly
    idx = np.where(np.array(asmb_ids) == asmb_i)[0]

    # Load every chain referenced by those transforms exactly once.
    # `asmb_chains[i]` is a comma-separated string and must be split:
    # iterating it directly yields single characters, which silently drops
    # multi-character chain IDs.
    needed = dict.fromkeys(c for i in idx
                           for c in asmb_chains[i].split(',')
                           if c in known_chains)
    chains = {c: torch.load(f'{prefix}_{c}.pt') for c in needed}

    # generate assembly
    asmb: Dict[Any, torch.Tensor] = {}
    for k in idx:
        # pick k-th xform
        xform = meta[f'asmb_xform{k}']
        u = xform[:, :3, :3]
        r = xform[:, :3, 3]

        # Chains the k-th xform applies to, in the order listed for the
        # assembly (deduplicated) so the output order is deterministic.
        chains_k = [
            c for c in dict.fromkeys(asmb_chains[k].split(','))
            if c in known_chains
        ]

        # transform selected chains
        for c in chains_k:
            try:
                xyz = chains[c]['xyz']
            except KeyError:
                return {'seq': np.zeros(5)}
            xyz_ru = torch.einsum('bij,raj->brai', u, xyz) + r[:, None,
                                                                None, :]
            asmb.update({(c, k, i): xyz_i for i, xyz_i in enumerate(xyz_ru)})

    if not asmb:
        # Nothing to stack; `torch.cat` below would fail on empty lists.
        return {'seq': np.zeros(5)}

    # select chains which share considerable similarity to chid
    seqid = meta['tm'][chids == chid][0, :, 1]
    similar_chains = {
        ch_j
        for seqid_j, ch_j in zip(seqid, chids) if seqid_j > self.homo
    }

    # stack all chains in the assembly together
    seq: str = ''
    xyz_all: List[torch.Tensor] = []
    idx_all: List[torch.Tensor] = []
    masked: List[int] = []
    for counter, (key, v) in enumerate(asmb.items()):
        chain_name = key[0]
        seq += chains[chain_name]['seq']
        xyz_all.append(v)
        idx_all.append(torch.full((v.shape[0], ), counter))
        if chain_name in similar_chains:
            masked.append(counter)

    return {
        'seq': seq,
        'xyz': torch.cat(xyz_all, dim=0),
        'idx': torch.cat(idx_all, dim=0),
        'masked': torch.Tensor(masked).int(),
        'label': chain_id,
    }
|
|
316
|
+
|
|
317
|
+
def _process_pdb2(self, t: Dict[str, Any]) -> Dict[str, Any]:
    r"""Split an assembled complex into per-chain entries.

    Each distinct value of ``t['idx']`` is one chain. It is labelled with
    the symbol at that position in a 352-entry alphabet: 52 ASCII letters
    followed by the strings ``'0'`` to ``'299'``. Runs of ``'HHHHHH'``
    (presumably expression tags) found at, or up to four residues in from,
    either end of a chain are trimmed together with the residues outside
    them. Chains left with fewer than four residues are dropped.

    Returns:
        A dictionary with, per kept chain, ``seq_chain_{letter}`` and
        ``coords_chain_{letter}``. The latter holds the first four atom
        slots, named N/CA/C/O. It also contains ``'name'``,
        ``'masked_list'``, ``'visible_list'``, ``'num_of_chains'`` and the
        concatenated ``'seq'``.
    """
    chain_alphabet = (
        list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz') +
        [str(n) for n in range(300)])
    # (window into the untrimmed chain sequence, slice of `res` to keep).
    # Every matching rule is applied, in this order, to the running `res`.
    his_tag_rules = (
        (slice(-6, None), slice(None, -6)),
        (slice(0, 6), slice(6, None)),
        (slice(-7, -1), slice(None, -7)),
        (slice(-8, -2), slice(None, -8)),
        (slice(-9, -3), slice(None, -9)),
        (slice(-10, -4), slice(None, -10)),
        (slice(1, 7), slice(7, None)),
        (slice(2, 8), slice(8, None)),
        (slice(3, 9), slice(9, None)),
        (slice(4, 10), slice(10, None)),
    )
    residues = np.array(list(t['seq']))

    my_dict: Dict[str, Union[str, int, Dict[str, Any], List[Any]]] = {}
    concat_seq = ''
    mask_list: List[str] = []
    visible_list: List[str] = []

    for idx in np.unique(t['idx']):
        letter = chain_alphabet[idx]
        # Positions of this chain; the code below treats axis 1 of `res`
        # as the residue axis.
        res = np.argwhere(t['idx'] == idx)
        initial_sequence = ''.join(residues[res][0])
        for window, keep in his_tag_rules:
            if initial_sequence[window] == 'HHHHHH':
                res = res[:, keep]
        if res.shape[1] < 4:
            continue

        chain_seq = ''.join(residues[res][0])
        my_dict[f'seq_chain_{letter}'] = chain_seq
        concat_seq += chain_seq
        target = mask_list if idx in t['masked'] else visible_list
        target.append(letter)

        all_atoms = np.array(t['xyz'][res])[0]  # [L, 14, 3]
        my_dict[f'coords_chain_{letter}'] = {
            f'{atom}_chain_{letter}': all_atoms[:, slot, :].tolist()
            for slot, atom in enumerate(('N', 'CA', 'C', 'O'))
        }

    my_dict['name'] = t['label']
    my_dict['masked_list'] = mask_list
    my_dict['visible_list'] = visible_list
    my_dict['num_of_chains'] = len(mask_list) + len(visible_list)
    my_dict['seq'] = concat_seq
    return my_dict
|
|
373
|
+
|
|
374
|
+
def _process_pdb3(
|
|
375
|
+
self, b: Dict[str, Any]
|
|
376
|
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
|
|
377
|
+
torch.Tensor, torch.Tensor]:
|
|
378
|
+
L = len(b['seq'])
|
|
379
|
+
# residue idx with jumps across chains
|
|
380
|
+
residue_idx = -100 * np.ones([L], dtype=np.int32)
|
|
381
|
+
# get the list of masked / visible chains
|
|
382
|
+
masked_chains, visible_chains = b['masked_list'], b['visible_list']
|
|
383
|
+
visible_temp_dict, masked_temp_dict = {}, {}
|
|
384
|
+
for letter in masked_chains + visible_chains:
|
|
385
|
+
chain_seq = b[f'seq_chain_{letter}']
|
|
386
|
+
if letter in visible_chains:
|
|
387
|
+
visible_temp_dict[letter] = chain_seq
|
|
388
|
+
elif letter in masked_chains:
|
|
389
|
+
masked_temp_dict[letter] = chain_seq
|
|
390
|
+
# check for duplicate chains (same sequence but different identity)
|
|
391
|
+
for _, vm in masked_temp_dict.items():
|
|
392
|
+
for kv, vv in visible_temp_dict.items():
|
|
393
|
+
if vm == vv:
|
|
394
|
+
if kv not in masked_chains:
|
|
395
|
+
masked_chains.append(kv)
|
|
396
|
+
if kv in visible_chains:
|
|
397
|
+
visible_chains.remove(kv)
|
|
398
|
+
# build protein data structures
|
|
399
|
+
all_chains = masked_chains + visible_chains
|
|
400
|
+
np.random.shuffle(all_chains)
|
|
401
|
+
x_chain_list = []
|
|
402
|
+
chain_mask_list = []
|
|
403
|
+
chain_seq_list = []
|
|
404
|
+
chain_encoding_list = []
|
|
405
|
+
c, l0, l1 = 1, 0, 0
|
|
406
|
+
for letter in all_chains:
|
|
407
|
+
chain_seq = b[f'seq_chain_{letter}']
|
|
408
|
+
chain_length = len(chain_seq)
|
|
409
|
+
chain_coords = b[f'coords_chain_{letter}']
|
|
410
|
+
x_chain = np.stack([
|
|
411
|
+
chain_coords[c] for c in [
|
|
412
|
+
f'N_chain_{letter}', f'CA_chain_{letter}',
|
|
413
|
+
f'C_chain_{letter}', f'O_chain_{letter}'
|
|
414
|
+
]
|
|
415
|
+
], 1) # [chain_length, 4, 3]
|
|
416
|
+
x_chain_list.append(x_chain)
|
|
417
|
+
chain_seq_list.append(chain_seq)
|
|
418
|
+
if letter in visible_chains:
|
|
419
|
+
chain_mask = np.zeros(chain_length) # 0 for visible chains
|
|
420
|
+
elif letter in masked_chains:
|
|
421
|
+
chain_mask = np.ones(chain_length) # 1 for masked chains
|
|
422
|
+
chain_mask_list.append(chain_mask)
|
|
423
|
+
chain_encoding_list.append(c * np.ones(chain_length))
|
|
424
|
+
l1 += chain_length
|
|
425
|
+
residue_idx[l0:l1] = 100 * (c - 1) + np.arange(l0, l1)
|
|
426
|
+
l0 += chain_length
|
|
427
|
+
c += 1
|
|
428
|
+
x_chain_all = np.concatenate(x_chain_list, 0) # [L, 4, 3]
|
|
429
|
+
chain_seq_all = "".join(chain_seq_list)
|
|
430
|
+
# [L,] 1.0 for places that need to be predicted
|
|
431
|
+
chain_mask_all = np.concatenate(chain_mask_list, 0)
|
|
432
|
+
chain_encoding_all = np.concatenate(chain_encoding_list, 0)
|
|
433
|
+
|
|
434
|
+
# Convert to labels
|
|
435
|
+
alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
|
|
436
|
+
chain_seq_label_all = np.asarray(
|
|
437
|
+
[alphabet.index(a) for a in chain_seq_all], dtype=np.int32)
|
|
438
|
+
|
|
439
|
+
isnan = np.isnan(x_chain_all)
|
|
440
|
+
mask = np.isfinite(np.sum(x_chain_all, (1, 2))).astype(np.float32)
|
|
441
|
+
x_chain_all[isnan] = 0.
|
|
442
|
+
|
|
443
|
+
# Conversion
|
|
444
|
+
return (
|
|
445
|
+
torch.from_numpy(x_chain_all).to(dtype=torch.float32),
|
|
446
|
+
torch.from_numpy(chain_seq_label_all).to(dtype=torch.long),
|
|
447
|
+
torch.from_numpy(mask).to(dtype=torch.float32),
|
|
448
|
+
torch.from_numpy(chain_mask_all).to(dtype=torch.float32),
|
|
449
|
+
torch.from_numpy(residue_idx).to(dtype=torch.long),
|
|
450
|
+
torch.from_numpy(chain_encoding_all).to(dtype=torch.long),
|
|
451
|
+
)
|
torch_geometric/datasets/qm9.py
CHANGED
|
@@ -202,7 +202,7 @@ class QM9(InMemoryDataset):
|
|
|
202
202
|
from rdkit import Chem, RDLogger
|
|
203
203
|
from rdkit.Chem.rdchem import BondType as BT
|
|
204
204
|
from rdkit.Chem.rdchem import HybridizationType
|
|
205
|
-
RDLogger.DisableLog('rdApp.*')
|
|
205
|
+
RDLogger.DisableLog('rdApp.*')
|
|
206
206
|
WITH_RDKIT = True
|
|
207
207
|
|
|
208
208
|
except ImportError:
|