pyg-nightly 2.7.0.dev20250702__py3-none-any.whl → 2.7.0.dev20250704__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyg-nightly
3
- Version: 2.7.0.dev20250702
3
+ Version: 2.7.0.dev20250704
4
4
  Summary: Graph Neural Network Library for PyTorch
5
5
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
6
6
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
1
- torch_geometric/__init__.py,sha256=ap-t4q8f9aTE0oAW_K5390u2Mlk8-S76rdeUEgPzglo,2250
1
+ torch_geometric/__init__.py,sha256=GOuL0XBOcsFqK-Q-c_STDpzZAG-vsctiDiU_Tg9W3t8,2250
2
2
  torch_geometric/_compile.py,sha256=9yqMTBKatZPr40WavJz9FjNi7pQj8YZAZOyZmmRGXgc,1351
3
3
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
4
4
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -55,7 +55,7 @@ torch_geometric/data/temporal.py,sha256=WOJ6gFrTLikaLhUvotyUF5ql14FkE5Ox3hNkdSp6
55
55
  torch_geometric/data/view.py,sha256=XjkVSc-UWZFCT4DlXLShZtO8duhFQkS9gq88zZXANsk,1089
56
56
  torch_geometric/data/lightning/__init__.py,sha256=w3En1tJfy3kSqe1MycpOyZpHFO3fxBCgNCUOznPA3YU,178
57
57
  torch_geometric/data/lightning/datamodule.py,sha256=IjucsIKRcNv16DIqILQnqa_sz72a4-yivoySmEllv2o,29353
58
- torch_geometric/datasets/__init__.py,sha256=vIraHnDqD40Num-XwwNivjHQDboK9tmMvlZHjTAuljM,6291
58
+ torch_geometric/datasets/__init__.py,sha256=rgfUmjd9U3o8renKVl81Brscx4LOtwWmt6qAoaG41C4,6417
59
59
  torch_geometric/datasets/actor.py,sha256=oUxgJIX8bi5hJr1etWNYIFyVQNDDXi1nyVpHGGMEAGQ,4304
60
60
  torch_geometric/datasets/airfrans.py,sha256=8cCBmHPttrlKY_iwfyr-K-CUX_JEDjrIOg3r9dQSN7o,5439
61
61
  torch_geometric/datasets/airports.py,sha256=b3gkv3gY2JkUpmGiz36Z-g7EcnSfU8lBG1YsCOWdJ6k,3758
@@ -95,7 +95,7 @@ torch_geometric/datasets/gdelt_lite.py,sha256=zE1WagpgmsQARQhEgdCBtALRKyuQvIZqxT
95
95
  torch_geometric/datasets/ged_dataset.py,sha256=dtd-C6pCygNHLXgVfg3ZTWtTVHKT13Q3GlGrze1_rpo,9551
96
96
  torch_geometric/datasets/gemsec.py,sha256=oMTSryTgyed9z_4ydg3ql12KM-_35uqL1AoNls5nG8M,2820
97
97
  torch_geometric/datasets/geometry.py,sha256=-BxUMirZcUOf01c3avvF0b6wGPn-4S3Zj3Oau1RaJVk,4223
98
- torch_geometric/datasets/git_mol_dataset.py,sha256=LsS_dPYUpwhWXMBh17iT7IbjlLOP0fFzb-we9cuBDaQ,10681
98
+ torch_geometric/datasets/git_mol_dataset.py,sha256=l5u4U86tfjJdHtQPN7SM3Yjv25LD1Idtm7VHaqJqNik,10665
99
99
  torch_geometric/datasets/github.py,sha256=Qhqhkvi6eZ8VF_HqP1rL2iYToZavFNsQh7J1WdeM9dA,2687
100
100
  torch_geometric/datasets/gnn_benchmark_dataset.py,sha256=4P8n7czF-gf1egLYlAcSSvfB0GXIKpAbH5UjsuFld1M,6976
101
101
  torch_geometric/datasets/heterophilous_graph_dataset.py,sha256=yHHtwl4uPrid0vPOxvPV3sIS8HWdswar8FJ0h0OQ9is,4224
@@ -119,7 +119,7 @@ torch_geometric/datasets/medshapenet.py,sha256=eCBCXKpueweCwDSf_Q4_MwVA3IbJd04FS
119
119
  torch_geometric/datasets/mixhop_synthetic_dataset.py,sha256=4NNvTHUvvV6pcqQCyVDS5XhppXUeF2H9GTfFoc49eyU,3951
120
120
  torch_geometric/datasets/mnist_superpixels.py,sha256=o2ArbZ0_OE0u8VCaHmWwvngESlOFr9oM9dSEP_tjAS4,3340
121
121
  torch_geometric/datasets/modelnet.py,sha256=-qmLjlQiKVWmtHefAIIE97dQxEcaBfetMJnvgYZuwkg,5347
122
- torch_geometric/datasets/molecule_gpt_dataset.py,sha256=gVZv14PuZCanE4oxxHlqRNrvzGv6_KN318q5yFA3lS0,18797
122
+ torch_geometric/datasets/molecule_gpt_dataset.py,sha256=TFBduE3_3xxTFSHL3tirV-OAlBjSi6iHPOHJGQ_-tug,18785
123
123
  torch_geometric/datasets/molecule_net.py,sha256=pMzaJzd-LbBncZ0VoC87HfA8d1F4NwCWTb5YKvLM890,7404
124
124
  torch_geometric/datasets/movie_lens.py,sha256=M4Bu0Xus8IkW8GYzjxPxSdPXNbcCCx9cu6cncxBvLx8,4033
125
125
  torch_geometric/datasets/movie_lens_100k.py,sha256=eTpBAteM3jqTEtiwLxmhVj4r8JvftvPx8Hvs-3ZIHlU,6057
@@ -139,8 +139,9 @@ torch_geometric/datasets/pcqm4m.py,sha256=7ID_xXXIAyuNzYLI2lBWygZl9wGos-dbaz1b6E
139
139
  torch_geometric/datasets/planetoid.py,sha256=RksfwR_PI7qGVphs-T-4jXDepYwQCweMXElLm096hgg,7201
140
140
  torch_geometric/datasets/polblogs.py,sha256=IYzsvd4R0OojmOOZUoOdCwQYfwcTfth1PNtcBK1yOGc,3045
141
141
  torch_geometric/datasets/ppi.py,sha256=zPtg-omC7WYvr9Tzwkb7zNjpXLODsvxKxKdGEUswp2E,5030
142
+ torch_geometric/datasets/protein_mpnn_dataset.py,sha256=TTeTVJMo0Rlt2_h9bbZMKJe3rTJcjCgY5cXGyWteBfA,17756
142
143
  torch_geometric/datasets/qm7.py,sha256=bYyK8xlh9kTr5vqueNbLu9EAjIXkQH1KX1VWnjKfOJc,3323
143
- torch_geometric/datasets/qm9.py,sha256=XU2HTPbgJJ_6hT--X0J2xkXliCbt7_-hub9nuIUQlug,17213
144
+ torch_geometric/datasets/qm9.py,sha256=Ub1t8KNeWFZvw50_Qk-80yNFeYFDwdAeyQtp3JHZs7o,17197
144
145
  torch_geometric/datasets/rcdd.py,sha256=gvOoM1tw_X5QMyBB4FkMUwNErMXAvImyjz5twktBAh8,5317
145
146
  torch_geometric/datasets/reddit.py,sha256=QUgiKTaj6YTOYbgWgqV8mPYsctOui2ujaM8f8qy81v0,3131
146
147
  torch_geometric/datasets/reddit2.py,sha256=WSdrhbDPcUEG37XWNUd0uKnqgI911MOcfjXmgjbTPoQ,4291
@@ -153,6 +154,7 @@ torch_geometric/datasets/snap_dataset.py,sha256=deJvB6cpIQ3bu_pcWoqgEo1-Kl_NcFi7
153
154
  torch_geometric/datasets/suite_sparse.py,sha256=eqjH4vAUq872qdk3YdLkZSwlu6r7HHpTgK0vEVGmY1s,3278
154
155
  torch_geometric/datasets/tag_dataset.py,sha256=qTnwr2N1tbWYeLGbItfv70UxQ3n1rKesjeVU3kcOCP8,14757
155
156
  torch_geometric/datasets/taobao.py,sha256=CUcZpbWsNTasevflO8zqP0YvENy89P7wpKS4MHaDJ6Q,4170
157
+ torch_geometric/datasets/teeth3ds.py,sha256=hZvhcq9lsQENNFr5hk50w2T3CgxE_tlnQfrCgN6uIDQ,9919
156
158
  torch_geometric/datasets/tosca.py,sha256=nUSF8NQT1GlkwWQLshjWmr8xORsvRHzzIqhUyDCvABc,4632
157
159
  torch_geometric/datasets/tu_dataset.py,sha256=14OSaXBgVwT1dX2h1wZ3xVIwoo0GQBEfR3yWh6Q0VF0,7847
158
160
  torch_geometric/datasets/twitch.py,sha256=qfEerf-Uaojx2ZvegENowdG4E7RoUT_HUO9xtULadvo,3658
@@ -333,8 +335,9 @@ torch_geometric/nn/aggr/set_transformer.py,sha256=FG7_JizpFX14M6VSCwLSjYXYdJ1ZiQ
333
335
  torch_geometric/nn/aggr/sort.py,sha256=bvOOWnFkNOBOZih4rqVZQsjfeDX3vmXo1bpPSFD846w,2507
334
336
  torch_geometric/nn/aggr/utils.py,sha256=SQvdc0g6p_E2j0prA14MW2ekjEDvV-g545N0Q85uc-o,8625
335
337
  torch_geometric/nn/aggr/variance_preserving.py,sha256=fu-U_aGYpVLpgSFvVg0ONMe6nqoyv8tZ6Y35qMYTf9w,1126
336
- torch_geometric/nn/attention/__init__.py,sha256=wLKTmlfP7qL9sZHy4cmDFHEtdwa-MEKE1dT51L1_w10,192
338
+ torch_geometric/nn/attention/__init__.py,sha256=w-jDQFpVqARJKjttTgKkD9kkAqRJl4MpASCfiNYIfr0,263
337
339
  torch_geometric/nn/attention/performer.py,sha256=2PCDn4_-oNTao2-DkXIaoi18anP01OxRELF2pvp-jk8,7357
340
+ torch_geometric/nn/attention/polynormer.py,sha256=uBxGs0nldp6oGlByqbxgEk23VeXLEd6B3myS5BOKDRs,3998
338
341
  torch_geometric/nn/attention/qformer.py,sha256=7J-pWm_vpumK38IC-iCBz4oqL-BEIofEIxJ0wfjWq9A,2338
339
342
  torch_geometric/nn/attention/sgformer.py,sha256=OBC5HQxbY289bPDtwN8UbPH46To2GRTeVN-najogD-o,3747
340
343
  torch_geometric/nn/conv/__init__.py,sha256=8CK-DFG2PEo2ZaFyg-IUlQH8ecQoDDi556uv3ugeQyc,3572
@@ -374,7 +377,7 @@ torch_geometric/nn/conv/hgt_conv.py,sha256=lUhTWUMovMtn9yR_b2-kLNLqHChGOUl2OtXBY
374
377
  torch_geometric/nn/conv/hypergraph_conv.py,sha256=4BosbbqJyprlI6QjPqIfMxCqnARU_0mUn1zcAQhbw90,8691
375
378
  torch_geometric/nn/conv/le_conv.py,sha256=DonmmYZOKk5wIlTZzzIfNKqBY6MO0MRxYhyr0YtNz-Q,3494
376
379
  torch_geometric/nn/conv/lg_conv.py,sha256=8jMa79iPsOUbXEfBIc3wmbvAD8T3d1j37LeIFTX3Yag,2369
377
- torch_geometric/nn/conv/meshcnn_conv.py,sha256=Z6p9KwGc_Kj4XQnTWqzbXQzbbpVlMv7ga0DuDB0jLSg,22279
380
+ torch_geometric/nn/conv/meshcnn_conv.py,sha256=92zUcgfS0Fwv-MpddF4Ia1a65y7ddPAkazYf7D6kvwg,21951
378
381
  torch_geometric/nn/conv/message_passing.py,sha256=ZuTvSvodGy1GyAW4mHtuoMUuxclam-7opidYNY5IHm8,44377
379
382
  torch_geometric/nn/conv/mf_conv.py,sha256=SkOGMN1tFT9dcqy8xYowsB2ozw6QfkoArgR1BksZZaU,4340
380
383
  torch_geometric/nn/conv/mixhop_conv.py,sha256=qVDPWeWcnO7_eHM0ZnpKtr8SISjb4jp0xjgpoDrwjlk,4555
@@ -429,7 +432,7 @@ torch_geometric/nn/kge/distmult.py,sha256=dGQ0bVzjreZgFN1lXE23_IIidsiOq7ehPrMb-N
429
432
  torch_geometric/nn/kge/loader.py,sha256=5Uc1j3OUMQnBYSHDqL7pLCty1siFLzoPkztigYO2zP8,771
430
433
  torch_geometric/nn/kge/rotate.py,sha256=XLuO1AbyTt5cJxr97ZzoyAyIEsHKesgW5TvDmnGJAao,3208
431
434
  torch_geometric/nn/kge/transe.py,sha256=jlejq5BLMm-sb1wWcLDp7pZqCdelWBgjDIC8ctbjSdU,3088
432
- torch_geometric/nn/models/__init__.py,sha256=4mZ5dyiZ9aa1NaBth1qYV-hZdnG_Np1XWvRLB4Qv6RM,2338
435
+ torch_geometric/nn/models/__init__.py,sha256=71Hqc-ZMfCKn9lelFYDjpHXapbEa0wqVAd2OXCb1y5o,2448
433
436
  torch_geometric/nn/models/attentive_fp.py,sha256=1z3iTV2O5W9tqHFAdno8FeBFeXmuG-TDZk4lwwVh3Ac,6634
434
437
  torch_geometric/nn/models/attract_repel.py,sha256=h9OyogT0NY0xiT0DkpJHMxH6ZUmo8R-CmwZdKEwq8Ek,5277
435
438
  torch_geometric/nn/models/autoencoder.py,sha256=nGje-zty78Y3hxOJ9o0_6QziJjOvBlknk6z0_fDQwQU,10770
@@ -442,7 +445,7 @@ torch_geometric/nn/models/dimenet.py,sha256=O2rqEx5HWs_lMwRD8eq6WMkbqJaCLL5zgWUJ
442
445
  torch_geometric/nn/models/dimenet_utils.py,sha256=Eyn_EiJqwKvuYj6BtRpSxrzMG3v4Gk98X9MxZ7uvwm4,5069
443
446
  torch_geometric/nn/models/g_retriever.py,sha256=tVibbqM_r-1LnA3R3oVyzp0bpuN3qPoYqcU6LZ8dYEk,8260
444
447
  torch_geometric/nn/models/git_mol.py,sha256=Wc6Hx6RDDR7sDWRWHfA5eK9e9gFsrTZ9OLmpMfoj3pE,12676
445
- torch_geometric/nn/models/glem.py,sha256=PlLjfMM4lKLs7c7tRC4LVD8tj0jpUXNxcnGbYut7vBE,16624
448
+ torch_geometric/nn/models/glem.py,sha256=GlL_I63g-_5eTycSGRj720YntldQ-CQ351RaDPc6XAU,16674
446
449
  torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
447
450
  torch_geometric/nn/models/gpse.py,sha256=acEAeeicLgzKRL54WhvIFxjA5XViHgXgMEH-NgbMdqI,41971
448
451
  torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
@@ -459,6 +462,8 @@ torch_geometric/nn/models/molecule_gpt.py,sha256=k-XULH6jaurj-R2EE4sIWTkqlNqa3Cz
459
462
  torch_geometric/nn/models/neural_fingerprint.py,sha256=pTLJgU9Uh2Lnf9bggLj4cKI8YdEFcMF-9MALuubqbuQ,2378
460
463
  torch_geometric/nn/models/node2vec.py,sha256=81Ku4Rp4IwLEAy06KEgJ2fYtXXVL_uv_Hb8lBr6YXrE,7664
461
464
  torch_geometric/nn/models/pmlp.py,sha256=dcAASVSyQMMhItSfEJWPeAFh0R3tNCwAHwdrShwQ8o4,3538
465
+ torch_geometric/nn/models/polynormer.py,sha256=mayWdzdolT5PCt_Oo7UGG-JUripMHWB2lUWF1bh6goU,7640
466
+ torch_geometric/nn/models/protein_mpnn.py,sha256=QXHfltiJPmakpzgJKw_1vwCGBlszv9nfY4r4F38Sg9k,11031
462
467
  torch_geometric/nn/models/re_net.py,sha256=pz66q5b5BoGDNVQvpEGS2RGoeKvpjkYAv9r3WAuvITk,8986
463
468
  torch_geometric/nn/models/rect.py,sha256=2F3XyyvHTAEuqfJpiNB5M8pSGy738LhPiom5I-SDWqM,2808
464
469
  torch_geometric/nn/models/rev_gnn.py,sha256=Bpme087Zs227lcB0ODOKWsxaly67q96wseaRt6bacjs,11796
@@ -613,7 +618,7 @@ torch_geometric/utils/_tree_decomposition.py,sha256=ZtpjPQJgXbQWtSWjo-Fmhrov0DGO
613
618
  torch_geometric/utils/_trim_to_layer.py,sha256=cauOEzMJJK4w9BC-Pg1bHVncBYqG9XxQex3rn10BFjc,8339
614
619
  torch_geometric/utils/_unbatch.py,sha256=B0vjKI96PtHvSBG8F_lqvsiJE134aVjUurPZsG6UZRI,2378
615
620
  torch_geometric/utils/augmentation.py,sha256=1F0YCuaklZ9ZbXxdFV0oOoemWvLd8p60WvFo2chzl7E,8600
616
- torch_geometric/utils/convert.py,sha256=j0t_87c-U_-15YKFfkOZfloEc5NbjgeLIk851zHG8WA,21665
621
+ torch_geometric/utils/convert.py,sha256=RE5n5no74Xu39-QMWFE0-1RvTgykdK33ymyjF9WcuSs,21938
617
622
  torch_geometric/utils/cross_entropy.py,sha256=ZFS5bivtzv3EV9zqgKsekmuQyoZZggPSclhl_tRNHxo,3047
618
623
  torch_geometric/utils/dropout.py,sha256=gg0rDnD4FLvBaKSoLAkZwViAQflhLefJm6_Mju5dmQs,11416
619
624
  torch_geometric/utils/embedding.py,sha256=Ac_MPSrZGpw-e-gU6Yz-seUioC2WZxBSSzXFeclGwMk,5232
@@ -634,13 +639,13 @@ torch_geometric/utils/num_nodes.py,sha256=F15ciTFOe8AxjkUh1wKH7RLmJvQYYpz-l3pPPv
634
639
  torch_geometric/utils/ppr.py,sha256=ebiHbQqRJsQbGUI5xu-IkzQSQsgIaC71vgO0KcXIKAk,4055
635
640
  torch_geometric/utils/random.py,sha256=Rv5HlhG5310rytbT9EZ7xWLGKQfozfz1azvYi5nx2-U,5148
636
641
  torch_geometric/utils/repeat.py,sha256=RxCoRoEisaP6NouXPPW5tY1Rn-tIfrmpJPm0qGP6W8M,815
637
- torch_geometric/utils/smiles.py,sha256=lGQ2BwJ49uBrQfIxxPz8ceTO9Jo-XCjlLxs1ql3xrsA,7130
642
+ torch_geometric/utils/smiles.py,sha256=CFqeNtSBXQtY9Ex2gQzI0La490IpVVrm01QdRYEpV7w,7114
638
643
  torch_geometric/utils/sparse.py,sha256=1DbaEwdyvnzvg5qVjPlnWcEVDMkxrQLX1jJ0dr6P4js,25135
639
644
  torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5nUAUjw,6222
640
645
  torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
641
646
  torch_geometric/visualization/graph.py,sha256=mfZHXYfiU-CWMtfawYc80IxVwVmtK9hbIkSKhM_j7oI,14311
642
647
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
643
- pyg_nightly-2.7.0.dev20250702.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
644
- pyg_nightly-2.7.0.dev20250702.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
645
- pyg_nightly-2.7.0.dev20250702.dist-info/METADATA,sha256=66AyTfnfJvD0er8ePN_vOUgj6tD76JJy4QPaIvkh8bw,63005
646
- pyg_nightly-2.7.0.dev20250702.dist-info/RECORD,,
648
+ pyg_nightly-2.7.0.dev20250704.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
649
+ pyg_nightly-2.7.0.dev20250704.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
650
+ pyg_nightly-2.7.0.dev20250704.dist-info/METADATA,sha256=Nau44bIMI13OXEqYNOlI1hYfDz8FXpUoPydv-JxRW2Q,63005
651
+ pyg_nightly-2.7.0.dev20250704.dist-info/RECORD,,
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
31
31
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
32
32
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
33
33
 
34
- __version__ = '2.7.0.dev20250702'
34
+ __version__ = '2.7.0.dev20250704'
35
35
 
36
36
  __all__ = [
37
37
  'Index',
@@ -81,8 +81,10 @@ from .web_qsp_dataset import WebQSPDataset, CWQDataset
81
81
  from .git_mol_dataset import GitMolDataset
82
82
  from .molecule_gpt_dataset import MoleculeGPTDataset
83
83
  from .instruct_mol_dataset import InstructMolDataset
84
+ from .protein_mpnn_dataset import ProteinMPNNDataset
84
85
  from .tag_dataset import TAGDataset
85
86
  from .city import CityNetwork
87
+ from .teeth3ds import Teeth3DS
86
88
 
87
89
  from .dbp15k import DBP15K
88
90
  from .aminer import AMiner
@@ -201,8 +203,10 @@ homo_datasets = [
201
203
  'GitMolDataset',
202
204
  'MoleculeGPTDataset',
203
205
  'InstructMolDataset',
206
+ 'ProteinMPNNDataset',
204
207
  'TAGDataset',
205
208
  'CityNetwork',
209
+ 'Teeth3DS',
206
210
  ]
207
211
 
208
212
  hetero_datasets = [
@@ -102,7 +102,7 @@ class GitMolDataset(InMemoryDataset):
102
102
 
103
103
  try:
104
104
  from rdkit import Chem, RDLogger
105
- RDLogger.DisableLog('rdApp.*') # type: ignore
105
+ RDLogger.DisableLog('rdApp.*')
106
106
  WITH_RDKIT = True
107
107
 
108
108
  except ImportError:
@@ -438,7 +438,7 @@ class MoleculeGPTDataset(InMemoryDataset):
438
438
  for mol in tqdm(suppl):
439
439
  if mol.HasProp('PUBCHEM_COMPOUND_CID'):
440
440
  CID = mol.GetProp("PUBCHEM_COMPOUND_CID")
441
- CAN_SMILES = mol.GetProp("PUBCHEM_OPENEYE_CAN_SMILES")
441
+ CAN_SMILES = mol.GetProp("PUBCHEM_SMILES")
442
442
 
443
443
  m: Chem.Mol = Chem.MolFromSmiles(CAN_SMILES)
444
444
  if m is None:
@@ -0,0 +1,451 @@
1
+ import os
2
+ import pickle
3
+ import random
4
+ from collections import defaultdict
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ from torch_geometric.data import (
12
+ Data,
13
+ InMemoryDataset,
14
+ download_url,
15
+ extract_tar,
16
+ )
17
+
18
+
19
+ class ProteinMPNNDataset(InMemoryDataset):
20
+ r"""The ProteinMPNN dataset from the `"Robust deep learning based protein
21
+ sequence design using ProteinMPNN"
22
+ <https://www.biorxiv.org/content/10.1101/2022.06.03.494563v1>`_ paper.
23
+
24
+ Args:
25
+ root (str): Root directory where the dataset should be saved.
26
+ size (str): Size of the PDB information to train the model.
27
+ If :obj:`"small"`, loads the small dataset (229.4 MB).
28
+ If :obj:`"large"`, loads the large dataset (64.1 GB).
29
+ (default: :obj:`"small"`)
30
+ split (str, optional): If :obj:`"train"`, loads the training dataset.
31
+ If :obj:`"valid"`, loads the validation dataset.
32
+ If :obj:`"test"`, loads the test dataset.
33
+ (default: :obj:`"train"`)
34
+ datacut (str, optional): Date cutoff to filter the dataset.
35
+ (default: :obj:`"2030-01-01"`)
36
+ rescut (float, optional): PDB resolution cutoff.
37
+ (default: :obj:`3.5`)
38
+ homo (float, optional): Homology cutoff.
39
+ (default: :obj:`0.70`)
40
+ max_length (int, optional): Maximum length of the protein complex.
41
+ (default: :obj:`10000`)
42
+ num_units (int, optional): Number of units of the protein complex.
43
+ (default: :obj:`150`)
44
+ transform (callable, optional): A function/transform that takes in an
45
+ :obj:`torch_geometric.data.Data` object and returns a transformed
46
+ version. The data object will be transformed before every access.
47
+ (default: :obj:`None`)
48
+ pre_transform (callable, optional): A function/transform that takes in
49
+ an :obj:`torch_geometric.data.Data` object and returns a
50
+ transformed version. The data object will be transformed before
51
+ being saved to disk. (default: :obj:`None`)
52
+ pre_filter (callable, optional): A function that takes in an
53
+ :obj:`torch_geometric.data.Data` object and returns a boolean
54
+ value, indicating whether the data object should be included in the
55
+ final dataset. (default: :obj:`None`)
56
+ force_reload (bool, optional): Whether to re-process the dataset.
57
+ (default: :obj:`False`)
58
+ """
59
+
60
+ raw_url = {
61
+ 'small':
62
+ 'https://files.ipd.uw.edu/pub/training_sets/'
63
+ 'pdb_2021aug02_sample.tar.gz',
64
+ 'large':
65
+ 'https://files.ipd.uw.edu/pub/training_sets/'
66
+ 'pdb_2021aug02.tar.gz',
67
+ }
68
+
69
+ splits = {
70
+ 'train': 1,
71
+ 'valid': 2,
72
+ 'test': 3,
73
+ }
74
+
75
+ def __init__(
76
+ self,
77
+ root: str,
78
+ size: str = 'small',
79
+ split: str = 'train',
80
+ datacut: str = '2030-01-01',
81
+ rescut: float = 3.5,
82
+ homo: float = 0.70,
83
+ max_length: int = 10000,
84
+ num_units: int = 150,
85
+ transform: Optional[Callable] = None,
86
+ pre_transform: Optional[Callable] = None,
87
+ pre_filter: Optional[Callable] = None,
88
+ force_reload: bool = False,
89
+ ) -> None:
90
+ self.size = size
91
+ self.split = split
92
+ self.datacut = datacut
93
+ self.rescut = rescut
94
+ self.homo = homo
95
+ self.max_length = max_length
96
+ self.num_units = num_units
97
+
98
+ self.sub_folder = self.raw_url[self.size].split('/')[-1].split('.')[0]
99
+
100
+ super().__init__(root, transform, pre_transform, pre_filter,
101
+ force_reload=force_reload)
102
+ self.load(self.processed_paths[self.splits[self.split]])
103
+
104
+ @property
105
+ def raw_file_names(self) -> List[str]:
106
+ return [
107
+ f'{self.sub_folder}/{f}'
108
+ for f in ['list.csv', 'valid_clusters.txt', 'test_clusters.txt']
109
+ ]
110
+
111
+ @property
112
+ def processed_file_names(self) -> List[str]:
113
+ return ['splits.pkl', 'train.pt', 'valid.pt', 'test.pt']
114
+
115
+ def download(self) -> None:
116
+ file_path = download_url(self.raw_url[self.size], self.raw_dir)
117
+ extract_tar(file_path, self.raw_dir)
118
+ os.unlink(file_path)
119
+
120
+ def process(self) -> None:
121
+ alphabet_set = set(list('ACDEFGHIKLMNPQRSTVWYX'))
122
+ cluster_ids = self._process_split()
123
+ total_items = sum(len(items) for items in cluster_ids.values())
124
+ data_list = []
125
+
126
+ with tqdm(total=total_items, desc="Processing") as pbar:
127
+ for _, items in cluster_ids.items():
128
+ for chain_id, _ in items:
129
+ item = self._process_pdb1(chain_id)
130
+
131
+ if 'label' not in item:
132
+ pbar.update(1)
133
+ continue
134
+ if len(list(np.unique(item['idx']))) >= 352:
135
+ pbar.update(1)
136
+ continue
137
+
138
+ my_dict = self._process_pdb2(item)
139
+
140
+ if len(my_dict['seq']) > self.max_length:
141
+ pbar.update(1)
142
+ continue
143
+ bad_chars = set(list(
144
+ my_dict['seq'])).difference(alphabet_set)
145
+ if len(bad_chars) > 0:
146
+ pbar.update(1)
147
+ continue
148
+
149
+ x_chain_all, chain_seq_label_all, mask, chain_mask_all, residue_idx, chain_encoding_all = self._process_pdb3( # noqa: E501
150
+ my_dict)
151
+
152
+ data = Data(
153
+ x=x_chain_all, # [seq_len, 4, 3]
154
+ chain_seq_label=chain_seq_label_all, # [seq_len]
155
+ mask=mask, # [seq_len]
156
+ chain_mask_all=chain_mask_all, # [seq_len]
157
+ residue_idx=residue_idx, # [seq_len]
158
+ chain_encoding_all=chain_encoding_all, # [seq_len]
159
+ )
160
+
161
+ if self.pre_filter is not None and not self.pre_filter(
162
+ data):
163
+ continue
164
+ if self.pre_transform is not None:
165
+ data = self.pre_transform(data)
166
+
167
+ data_list.append(data)
168
+
169
+ if len(data_list) >= self.num_units:
170
+ pbar.update(total_items - pbar.n)
171
+ break
172
+ pbar.update(1)
173
+ else:
174
+ continue
175
+ break
176
+ self.save(data_list, self.processed_paths[self.splits[self.split]])
177
+
178
+ def _process_split(self) -> Dict[int, List[Tuple[str, int]]]:
179
+ import pandas as pd
180
+ save_path = self.processed_paths[0]
181
+
182
+ if os.path.exists(save_path):
183
+ print('Load split')
184
+ with open(save_path, 'rb') as f:
185
+ data = pickle.load(f)
186
+ else:
187
+ # CHAINID, DEPOSITION, RESOLUTION, HASH, CLUSTER, SEQUENCE
188
+ df = pd.read_csv(self.raw_paths[0])
189
+ df = df[(df['RESOLUTION'] <= self.rescut)
190
+ & (df['DEPOSITION'] <= self.datacut)]
191
+
192
+ val_ids = pd.read_csv(self.raw_paths[1], header=None)[0].tolist()
193
+ test_ids = pd.read_csv(self.raw_paths[2], header=None)[0].tolist()
194
+
195
+ # compile training and validation sets
196
+ data = {
197
+ 'train': defaultdict(list),
198
+ 'valid': defaultdict(list),
199
+ 'test': defaultdict(list),
200
+ }
201
+
202
+ for _, r in tqdm(df.iterrows(), desc='Processing split',
203
+ total=len(df)):
204
+ cluster_id = r['CLUSTER']
205
+ hash_id = r['HASH']
206
+ chain_id = r['CHAINID']
207
+ if cluster_id in val_ids:
208
+ data['valid'][cluster_id].append((chain_id, hash_id))
209
+ elif cluster_id in test_ids:
210
+ data['test'][cluster_id].append((chain_id, hash_id))
211
+ else:
212
+ data['train'][cluster_id].append((chain_id, hash_id))
213
+
214
+ with open(save_path, 'wb') as f:
215
+ pickle.dump(data, f)
216
+
217
+ return data[self.split]
218
+
219
+ def _process_pdb1(self, chain_id: str) -> Dict[str, Any]:
220
+ pdbid, chid = chain_id.split('_')
221
+ prefix = f'{self.raw_dir}/{self.sub_folder}/pdb/{pdbid[1:3]}/{pdbid}'
222
+ # load metadata
223
+ if not os.path.isfile(f'{prefix}.pt'):
224
+ return {'seq': np.zeros(5)}
225
+ meta = torch.load(f'{prefix}.pt')
226
+ asmb_ids = meta['asmb_ids']
227
+ asmb_chains = meta['asmb_chains']
228
+ chids = np.array(meta['chains'])
229
+
230
+ # find candidate assemblies which contain chid chain
231
+ asmb_candidates = {
232
+ a
233
+ for a, b in zip(asmb_ids, asmb_chains) if chid in b.split(',')
234
+ }
235
+
236
 + # if the chain is missing from all the assemblies
237
+ # then return this chain alone
238
+ if len(asmb_candidates) < 1:
239
+ chain = torch.load(f'{prefix}_{chid}.pt')
240
+ L = len(chain['seq'])
241
+ return {
242
+ 'seq': chain['seq'],
243
+ 'xyz': chain['xyz'],
244
+ 'idx': torch.zeros(L).int(),
245
+ 'masked': torch.Tensor([0]).int(),
246
+ 'label': chain_id,
247
+ }
248
+
249
+ # randomly pick one assembly from candidates
250
+ asmb_i = random.sample(list(asmb_candidates), 1)
251
+
252
+ # indices of selected transforms
253
+ idx = np.where(np.array(asmb_ids) == asmb_i)[0]
254
+
255
+ # load relevant chains
256
+ chains = {
257
+ c: torch.load(f'{prefix}_{c}.pt')
258
+ for i in idx
259
+ for c in asmb_chains[i] if c in meta['chains']
260
+ }
261
+
262
+ # generate assembly
263
+ asmb = {}
264
+ for k in idx:
265
+
266
+ # pick k-th xform
267
+ xform = meta[f'asmb_xform{k}']
268
+ u = xform[:, :3, :3]
269
+ r = xform[:, :3, 3]
270
+
271
+ # select chains which k-th xform should be applied to
272
+ s1 = set(meta['chains'])
273
+ s2 = set(asmb_chains[k].split(','))
274
+ chains_k = s1 & s2
275
+
276
+ # transform selected chains
277
+ for c in chains_k:
278
+ try:
279
+ xyz = chains[c]['xyz']
280
+ xyz_ru = torch.einsum('bij,raj->brai', u, xyz) + r[:, None,
281
+ None, :]
282
+ asmb.update({
283
+ (c, k, i): xyz_i
284
+ for i, xyz_i in enumerate(xyz_ru)
285
+ })
286
+ except KeyError:
287
+ return {'seq': np.zeros(5)}
288
+
289
+ # select chains which share considerable similarity to chid
290
+ seqid = meta['tm'][chids == chid][0, :, 1]
291
+ homo = {
292
+ ch_j
293
+ for seqid_j, ch_j in zip(seqid, chids) if seqid_j > self.homo
294
+ }
295
+ # stack all chains in the assembly together
296
+ seq: str = ''
297
+ xyz_all: List[torch.Tensor] = []
298
+ idx_all: List[torch.Tensor] = []
299
+ masked: List[int] = []
300
+ seq_list = []
301
+ for counter, (k, v) in enumerate(asmb.items()):
302
+ seq += chains[k[0]]['seq']
303
+ seq_list.append(chains[k[0]]['seq'])
304
+ xyz_all.append(v)
305
+ idx_all.append(torch.full((v.shape[0], ), counter))
306
+ if k[0] in homo:
307
+ masked.append(counter)
308
+
309
+ return {
310
+ 'seq': seq,
311
+ 'xyz': torch.cat(xyz_all, dim=0),
312
+ 'idx': torch.cat(idx_all, dim=0),
313
+ 'masked': torch.Tensor(masked).int(),
314
+ 'label': chain_id,
315
+ }
316
+
317
+ def _process_pdb2(self, t: Dict[str, Any]) -> Dict[str, Any]:
318
+ init_alphabet = list(
319
+ 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz')
320
+ extra_alphabet = [str(item) for item in list(np.arange(300))]
321
+ chain_alphabet = init_alphabet + extra_alphabet
322
+ my_dict: Dict[str, Union[str, int, Dict[str, Any], List[Any]]] = {}
323
+ concat_seq = ''
324
+ mask_list = []
325
+ visible_list = []
326
+ for idx in list(np.unique(t['idx'])):
327
+ letter = chain_alphabet[idx]
328
+ res = np.argwhere(t['idx'] == idx)
329
+ initial_sequence = "".join(list(
330
+ np.array(list(t['seq']))[res][
331
+ 0,
332
+ ]))
333
+ if initial_sequence[-6:] == "HHHHHH":
334
+ res = res[:, :-6]
335
+ if initial_sequence[0:6] == "HHHHHH":
336
+ res = res[:, 6:]
337
+ if initial_sequence[-7:-1] == "HHHHHH":
338
+ res = res[:, :-7]
339
+ if initial_sequence[-8:-2] == "HHHHHH":
340
+ res = res[:, :-8]
341
+ if initial_sequence[-9:-3] == "HHHHHH":
342
+ res = res[:, :-9]
343
+ if initial_sequence[-10:-4] == "HHHHHH":
344
+ res = res[:, :-10]
345
+ if initial_sequence[1:7] == "HHHHHH":
346
+ res = res[:, 7:]
347
+ if initial_sequence[2:8] == "HHHHHH":
348
+ res = res[:, 8:]
349
+ if initial_sequence[3:9] == "HHHHHH":
350
+ res = res[:, 9:]
351
+ if initial_sequence[4:10] == "HHHHHH":
352
+ res = res[:, 10:]
353
+ if res.shape[1] >= 4:
354
+ chain_seq = "".join(list(np.array(list(t['seq']))[res][0]))
355
+ my_dict[f'seq_chain_{letter}'] = chain_seq
356
+ concat_seq += chain_seq
357
+ if idx in t['masked']:
358
+ mask_list.append(letter)
359
+ else:
360
+ visible_list.append(letter)
361
+ coords_dict_chain = {}
362
+ all_atoms = np.array(t['xyz'][res])[0] # [L, 14, 3]
363
+ for i, c in enumerate(['N', 'CA', 'C', 'O']):
364
+ coords_dict_chain[
365
+ f'{c}_chain_{letter}'] = all_atoms[:, i, :].tolist()
366
+ my_dict[f'coords_chain_{letter}'] = coords_dict_chain
367
+ my_dict['name'] = t['label']
368
+ my_dict['masked_list'] = mask_list
369
+ my_dict['visible_list'] = visible_list
370
+ my_dict['num_of_chains'] = len(mask_list) + len(visible_list)
371
+ my_dict['seq'] = concat_seq
372
+ return my_dict
373
+
374
+ def _process_pdb3(
375
+ self, b: Dict[str, Any]
376
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
377
+ torch.Tensor, torch.Tensor]:
378
+ L = len(b['seq'])
379
+ # residue idx with jumps across chains
380
+ residue_idx = -100 * np.ones([L], dtype=np.int32)
381
+ # get the list of masked / visible chains
382
+ masked_chains, visible_chains = b['masked_list'], b['visible_list']
383
+ visible_temp_dict, masked_temp_dict = {}, {}
384
+ for letter in masked_chains + visible_chains:
385
+ chain_seq = b[f'seq_chain_{letter}']
386
+ if letter in visible_chains:
387
+ visible_temp_dict[letter] = chain_seq
388
+ elif letter in masked_chains:
389
+ masked_temp_dict[letter] = chain_seq
390
+ # check for duplicate chains (same sequence but different identity)
391
+ for _, vm in masked_temp_dict.items():
392
+ for kv, vv in visible_temp_dict.items():
393
+ if vm == vv:
394
+ if kv not in masked_chains:
395
+ masked_chains.append(kv)
396
+ if kv in visible_chains:
397
+ visible_chains.remove(kv)
398
+ # build protein data structures
399
+ all_chains = masked_chains + visible_chains
400
+ np.random.shuffle(all_chains)
401
+ x_chain_list = []
402
+ chain_mask_list = []
403
+ chain_seq_list = []
404
+ chain_encoding_list = []
405
+ c, l0, l1 = 1, 0, 0
406
+ for letter in all_chains:
407
+ chain_seq = b[f'seq_chain_{letter}']
408
+ chain_length = len(chain_seq)
409
+ chain_coords = b[f'coords_chain_{letter}']
410
+ x_chain = np.stack([
411
+ chain_coords[c] for c in [
412
+ f'N_chain_{letter}', f'CA_chain_{letter}',
413
+ f'C_chain_{letter}', f'O_chain_{letter}'
414
+ ]
415
+ ], 1) # [chain_length, 4, 3]
416
+ x_chain_list.append(x_chain)
417
+ chain_seq_list.append(chain_seq)
418
+ if letter in visible_chains:
419
+ chain_mask = np.zeros(chain_length) # 0 for visible chains
420
+ elif letter in masked_chains:
421
+ chain_mask = np.ones(chain_length) # 1 for masked chains
422
+ chain_mask_list.append(chain_mask)
423
+ chain_encoding_list.append(c * np.ones(chain_length))
424
+ l1 += chain_length
425
+ residue_idx[l0:l1] = 100 * (c - 1) + np.arange(l0, l1)
426
+ l0 += chain_length
427
+ c += 1
428
+ x_chain_all = np.concatenate(x_chain_list, 0) # [L, 4, 3]
429
+ chain_seq_all = "".join(chain_seq_list)
430
+ # [L,] 1.0 for places that need to be predicted
431
+ chain_mask_all = np.concatenate(chain_mask_list, 0)
432
+ chain_encoding_all = np.concatenate(chain_encoding_list, 0)
433
+
434
+ # Convert to labels
435
+ alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
436
+ chain_seq_label_all = np.asarray(
437
+ [alphabet.index(a) for a in chain_seq_all], dtype=np.int32)
438
+
439
+ isnan = np.isnan(x_chain_all)
440
+ mask = np.isfinite(np.sum(x_chain_all, (1, 2))).astype(np.float32)
441
+ x_chain_all[isnan] = 0.
442
+
443
+ # Conversion
444
+ return (
445
+ torch.from_numpy(x_chain_all).to(dtype=torch.float32),
446
+ torch.from_numpy(chain_seq_label_all).to(dtype=torch.long),
447
+ torch.from_numpy(mask).to(dtype=torch.float32),
448
+ torch.from_numpy(chain_mask_all).to(dtype=torch.float32),
449
+ torch.from_numpy(residue_idx).to(dtype=torch.long),
450
+ torch.from_numpy(chain_encoding_all).to(dtype=torch.long),
451
+ )
@@ -202,7 +202,7 @@ class QM9(InMemoryDataset):
202
202
  from rdkit import Chem, RDLogger
203
203
  from rdkit.Chem.rdchem import BondType as BT
204
204
  from rdkit.Chem.rdchem import HybridizationType
205
- RDLogger.DisableLog('rdApp.*') # type: ignore
205
+ RDLogger.DisableLog('rdApp.*')
206
206
  WITH_RDKIT = True
207
207
 
208
208
  except ImportError: