pyg-nightly 2.7.0.dev20250702__py3-none-any.whl → 2.7.0.dev20250703__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of pyg-nightly might be problematic.

METADATA:
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pyg-nightly
- Version: 2.7.0.dev20250702
+ Version: 2.7.0.dev20250703
  Summary: Graph Neural Network Library for PyTorch
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
  Author-email: Matthias Fey <matthias@pyg.org>
RECORD:
@@ -1,4 +1,4 @@
- torch_geometric/__init__.py,sha256=ap-t4q8f9aTE0oAW_K5390u2Mlk8-S76rdeUEgPzglo,2250
+ torch_geometric/__init__.py,sha256=8_AcpgPpDfVr7gwK3sYPD-pu9JYUmziOQzs9SPRvqcE,2250
  torch_geometric/_compile.py,sha256=9yqMTBKatZPr40WavJz9FjNi7pQj8YZAZOyZmmRGXgc,1351
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -55,7 +55,7 @@ torch_geometric/data/temporal.py,sha256=WOJ6gFrTLikaLhUvotyUF5ql14FkE5Ox3hNkdSp6
  torch_geometric/data/view.py,sha256=XjkVSc-UWZFCT4DlXLShZtO8duhFQkS9gq88zZXANsk,1089
  torch_geometric/data/lightning/__init__.py,sha256=w3En1tJfy3kSqe1MycpOyZpHFO3fxBCgNCUOznPA3YU,178
  torch_geometric/data/lightning/datamodule.py,sha256=IjucsIKRcNv16DIqILQnqa_sz72a4-yivoySmEllv2o,29353
- torch_geometric/datasets/__init__.py,sha256=vIraHnDqD40Num-XwwNivjHQDboK9tmMvlZHjTAuljM,6291
+ torch_geometric/datasets/__init__.py,sha256=rgfUmjd9U3o8renKVl81Brscx4LOtwWmt6qAoaG41C4,6417
  torch_geometric/datasets/actor.py,sha256=oUxgJIX8bi5hJr1etWNYIFyVQNDDXi1nyVpHGGMEAGQ,4304
  torch_geometric/datasets/airfrans.py,sha256=8cCBmHPttrlKY_iwfyr-K-CUX_JEDjrIOg3r9dQSN7o,5439
  torch_geometric/datasets/airports.py,sha256=b3gkv3gY2JkUpmGiz36Z-g7EcnSfU8lBG1YsCOWdJ6k,3758
@@ -95,7 +95,7 @@ torch_geometric/datasets/gdelt_lite.py,sha256=zE1WagpgmsQARQhEgdCBtALRKyuQvIZqxT
  torch_geometric/datasets/ged_dataset.py,sha256=dtd-C6pCygNHLXgVfg3ZTWtTVHKT13Q3GlGrze1_rpo,9551
  torch_geometric/datasets/gemsec.py,sha256=oMTSryTgyed9z_4ydg3ql12KM-_35uqL1AoNls5nG8M,2820
  torch_geometric/datasets/geometry.py,sha256=-BxUMirZcUOf01c3avvF0b6wGPn-4S3Zj3Oau1RaJVk,4223
- torch_geometric/datasets/git_mol_dataset.py,sha256=LsS_dPYUpwhWXMBh17iT7IbjlLOP0fFzb-we9cuBDaQ,10681
+ torch_geometric/datasets/git_mol_dataset.py,sha256=l5u4U86tfjJdHtQPN7SM3Yjv25LD1Idtm7VHaqJqNik,10665
  torch_geometric/datasets/github.py,sha256=Qhqhkvi6eZ8VF_HqP1rL2iYToZavFNsQh7J1WdeM9dA,2687
  torch_geometric/datasets/gnn_benchmark_dataset.py,sha256=4P8n7czF-gf1egLYlAcSSvfB0GXIKpAbH5UjsuFld1M,6976
  torch_geometric/datasets/heterophilous_graph_dataset.py,sha256=yHHtwl4uPrid0vPOxvPV3sIS8HWdswar8FJ0h0OQ9is,4224
@@ -139,8 +139,9 @@ torch_geometric/datasets/pcqm4m.py,sha256=7ID_xXXIAyuNzYLI2lBWygZl9wGos-dbaz1b6E
  torch_geometric/datasets/planetoid.py,sha256=RksfwR_PI7qGVphs-T-4jXDepYwQCweMXElLm096hgg,7201
  torch_geometric/datasets/polblogs.py,sha256=IYzsvd4R0OojmOOZUoOdCwQYfwcTfth1PNtcBK1yOGc,3045
  torch_geometric/datasets/ppi.py,sha256=zPtg-omC7WYvr9Tzwkb7zNjpXLODsvxKxKdGEUswp2E,5030
+ torch_geometric/datasets/protein_mpnn_dataset.py,sha256=TTeTVJMo0Rlt2_h9bbZMKJe3rTJcjCgY5cXGyWteBfA,17756
  torch_geometric/datasets/qm7.py,sha256=bYyK8xlh9kTr5vqueNbLu9EAjIXkQH1KX1VWnjKfOJc,3323
- torch_geometric/datasets/qm9.py,sha256=XU2HTPbgJJ_6hT--X0J2xkXliCbt7_-hub9nuIUQlug,17213
+ torch_geometric/datasets/qm9.py,sha256=Ub1t8KNeWFZvw50_Qk-80yNFeYFDwdAeyQtp3JHZs7o,17197
  torch_geometric/datasets/rcdd.py,sha256=gvOoM1tw_X5QMyBB4FkMUwNErMXAvImyjz5twktBAh8,5317
  torch_geometric/datasets/reddit.py,sha256=QUgiKTaj6YTOYbgWgqV8mPYsctOui2ujaM8f8qy81v0,3131
  torch_geometric/datasets/reddit2.py,sha256=WSdrhbDPcUEG37XWNUd0uKnqgI911MOcfjXmgjbTPoQ,4291
@@ -153,6 +154,7 @@ torch_geometric/datasets/snap_dataset.py,sha256=deJvB6cpIQ3bu_pcWoqgEo1-Kl_NcFi7
  torch_geometric/datasets/suite_sparse.py,sha256=eqjH4vAUq872qdk3YdLkZSwlu6r7HHpTgK0vEVGmY1s,3278
  torch_geometric/datasets/tag_dataset.py,sha256=qTnwr2N1tbWYeLGbItfv70UxQ3n1rKesjeVU3kcOCP8,14757
  torch_geometric/datasets/taobao.py,sha256=CUcZpbWsNTasevflO8zqP0YvENy89P7wpKS4MHaDJ6Q,4170
+ torch_geometric/datasets/teeth3ds.py,sha256=hZvhcq9lsQENNFr5hk50w2T3CgxE_tlnQfrCgN6uIDQ,9919
  torch_geometric/datasets/tosca.py,sha256=nUSF8NQT1GlkwWQLshjWmr8xORsvRHzzIqhUyDCvABc,4632
  torch_geometric/datasets/tu_dataset.py,sha256=14OSaXBgVwT1dX2h1wZ3xVIwoo0GQBEfR3yWh6Q0VF0,7847
  torch_geometric/datasets/twitch.py,sha256=qfEerf-Uaojx2ZvegENowdG4E7RoUT_HUO9xtULadvo,3658
@@ -374,7 +376,7 @@ torch_geometric/nn/conv/hgt_conv.py,sha256=lUhTWUMovMtn9yR_b2-kLNLqHChGOUl2OtXBY
  torch_geometric/nn/conv/hypergraph_conv.py,sha256=4BosbbqJyprlI6QjPqIfMxCqnARU_0mUn1zcAQhbw90,8691
  torch_geometric/nn/conv/le_conv.py,sha256=DonmmYZOKk5wIlTZzzIfNKqBY6MO0MRxYhyr0YtNz-Q,3494
  torch_geometric/nn/conv/lg_conv.py,sha256=8jMa79iPsOUbXEfBIc3wmbvAD8T3d1j37LeIFTX3Yag,2369
- torch_geometric/nn/conv/meshcnn_conv.py,sha256=Z6p9KwGc_Kj4XQnTWqzbXQzbbpVlMv7ga0DuDB0jLSg,22279
+ torch_geometric/nn/conv/meshcnn_conv.py,sha256=92zUcgfS0Fwv-MpddF4Ia1a65y7ddPAkazYf7D6kvwg,21951
  torch_geometric/nn/conv/message_passing.py,sha256=ZuTvSvodGy1GyAW4mHtuoMUuxclam-7opidYNY5IHm8,44377
  torch_geometric/nn/conv/mf_conv.py,sha256=SkOGMN1tFT9dcqy8xYowsB2ozw6QfkoArgR1BksZZaU,4340
  torch_geometric/nn/conv/mixhop_conv.py,sha256=qVDPWeWcnO7_eHM0ZnpKtr8SISjb4jp0xjgpoDrwjlk,4555
@@ -429,7 +431,7 @@ torch_geometric/nn/kge/distmult.py,sha256=dGQ0bVzjreZgFN1lXE23_IIidsiOq7ehPrMb-N
  torch_geometric/nn/kge/loader.py,sha256=5Uc1j3OUMQnBYSHDqL7pLCty1siFLzoPkztigYO2zP8,771
  torch_geometric/nn/kge/rotate.py,sha256=XLuO1AbyTt5cJxr97ZzoyAyIEsHKesgW5TvDmnGJAao,3208
  torch_geometric/nn/kge/transe.py,sha256=jlejq5BLMm-sb1wWcLDp7pZqCdelWBgjDIC8ctbjSdU,3088
- torch_geometric/nn/models/__init__.py,sha256=4mZ5dyiZ9aa1NaBth1qYV-hZdnG_Np1XWvRLB4Qv6RM,2338
+ torch_geometric/nn/models/__init__.py,sha256=fbHQauZw9Snvl2PuN5cjZoAW8SwUl6E-p2IOmwUKB3A,2395
  torch_geometric/nn/models/attentive_fp.py,sha256=1z3iTV2O5W9tqHFAdno8FeBFeXmuG-TDZk4lwwVh3Ac,6634
  torch_geometric/nn/models/attract_repel.py,sha256=h9OyogT0NY0xiT0DkpJHMxH6ZUmo8R-CmwZdKEwq8Ek,5277
  torch_geometric/nn/models/autoencoder.py,sha256=nGje-zty78Y3hxOJ9o0_6QziJjOvBlknk6z0_fDQwQU,10770
@@ -442,7 +444,7 @@ torch_geometric/nn/models/dimenet.py,sha256=O2rqEx5HWs_lMwRD8eq6WMkbqJaCLL5zgWUJ
  torch_geometric/nn/models/dimenet_utils.py,sha256=Eyn_EiJqwKvuYj6BtRpSxrzMG3v4Gk98X9MxZ7uvwm4,5069
  torch_geometric/nn/models/g_retriever.py,sha256=tVibbqM_r-1LnA3R3oVyzp0bpuN3qPoYqcU6LZ8dYEk,8260
  torch_geometric/nn/models/git_mol.py,sha256=Wc6Hx6RDDR7sDWRWHfA5eK9e9gFsrTZ9OLmpMfoj3pE,12676
- torch_geometric/nn/models/glem.py,sha256=PlLjfMM4lKLs7c7tRC4LVD8tj0jpUXNxcnGbYut7vBE,16624
+ torch_geometric/nn/models/glem.py,sha256=GlL_I63g-_5eTycSGRj720YntldQ-CQ351RaDPc6XAU,16674
  torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
  torch_geometric/nn/models/gpse.py,sha256=acEAeeicLgzKRL54WhvIFxjA5XViHgXgMEH-NgbMdqI,41971
  torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
@@ -459,6 +461,7 @@ torch_geometric/nn/models/molecule_gpt.py,sha256=k-XULH6jaurj-R2EE4sIWTkqlNqa3Cz
  torch_geometric/nn/models/neural_fingerprint.py,sha256=pTLJgU9Uh2Lnf9bggLj4cKI8YdEFcMF-9MALuubqbuQ,2378
  torch_geometric/nn/models/node2vec.py,sha256=81Ku4Rp4IwLEAy06KEgJ2fYtXXVL_uv_Hb8lBr6YXrE,7664
  torch_geometric/nn/models/pmlp.py,sha256=dcAASVSyQMMhItSfEJWPeAFh0R3tNCwAHwdrShwQ8o4,3538
+ torch_geometric/nn/models/protein_mpnn.py,sha256=QXHfltiJPmakpzgJKw_1vwCGBlszv9nfY4r4F38Sg9k,11031
  torch_geometric/nn/models/re_net.py,sha256=pz66q5b5BoGDNVQvpEGS2RGoeKvpjkYAv9r3WAuvITk,8986
  torch_geometric/nn/models/rect.py,sha256=2F3XyyvHTAEuqfJpiNB5M8pSGy738LhPiom5I-SDWqM,2808
  torch_geometric/nn/models/rev_gnn.py,sha256=Bpme087Zs227lcB0ODOKWsxaly67q96wseaRt6bacjs,11796
@@ -613,7 +616,7 @@ torch_geometric/utils/_tree_decomposition.py,sha256=ZtpjPQJgXbQWtSWjo-Fmhrov0DGO
  torch_geometric/utils/_trim_to_layer.py,sha256=cauOEzMJJK4w9BC-Pg1bHVncBYqG9XxQex3rn10BFjc,8339
  torch_geometric/utils/_unbatch.py,sha256=B0vjKI96PtHvSBG8F_lqvsiJE134aVjUurPZsG6UZRI,2378
  torch_geometric/utils/augmentation.py,sha256=1F0YCuaklZ9ZbXxdFV0oOoemWvLd8p60WvFo2chzl7E,8600
- torch_geometric/utils/convert.py,sha256=j0t_87c-U_-15YKFfkOZfloEc5NbjgeLIk851zHG8WA,21665
+ torch_geometric/utils/convert.py,sha256=RE5n5no74Xu39-QMWFE0-1RvTgykdK33ymyjF9WcuSs,21938
  torch_geometric/utils/cross_entropy.py,sha256=ZFS5bivtzv3EV9zqgKsekmuQyoZZggPSclhl_tRNHxo,3047
  torch_geometric/utils/dropout.py,sha256=gg0rDnD4FLvBaKSoLAkZwViAQflhLefJm6_Mju5dmQs,11416
  torch_geometric/utils/embedding.py,sha256=Ac_MPSrZGpw-e-gU6Yz-seUioC2WZxBSSzXFeclGwMk,5232
@@ -634,13 +637,13 @@ torch_geometric/utils/num_nodes.py,sha256=F15ciTFOe8AxjkUh1wKH7RLmJvQYYpz-l3pPPv
  torch_geometric/utils/ppr.py,sha256=ebiHbQqRJsQbGUI5xu-IkzQSQsgIaC71vgO0KcXIKAk,4055
  torch_geometric/utils/random.py,sha256=Rv5HlhG5310rytbT9EZ7xWLGKQfozfz1azvYi5nx2-U,5148
  torch_geometric/utils/repeat.py,sha256=RxCoRoEisaP6NouXPPW5tY1Rn-tIfrmpJPm0qGP6W8M,815
- torch_geometric/utils/smiles.py,sha256=lGQ2BwJ49uBrQfIxxPz8ceTO9Jo-XCjlLxs1ql3xrsA,7130
+ torch_geometric/utils/smiles.py,sha256=CFqeNtSBXQtY9Ex2gQzI0La490IpVVrm01QdRYEpV7w,7114
  torch_geometric/utils/sparse.py,sha256=1DbaEwdyvnzvg5qVjPlnWcEVDMkxrQLX1jJ0dr6P4js,25135
  torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5nUAUjw,6222
  torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
  torch_geometric/visualization/graph.py,sha256=mfZHXYfiU-CWMtfawYc80IxVwVmtK9hbIkSKhM_j7oI,14311
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
- pyg_nightly-2.7.0.dev20250702.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
- pyg_nightly-2.7.0.dev20250702.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
- pyg_nightly-2.7.0.dev20250702.dist-info/METADATA,sha256=66AyTfnfJvD0er8ePN_vOUgj6tD76JJy4QPaIvkh8bw,63005
- pyg_nightly-2.7.0.dev20250702.dist-info/RECORD,,
+ pyg_nightly-2.7.0.dev20250703.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
+ pyg_nightly-2.7.0.dev20250703.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+ pyg_nightly-2.7.0.dev20250703.dist-info/METADATA,sha256=lm-xVNswGkfSaczkhV_ANegM0oGwBEfhCv75gwj0X5Q,63005
+ pyg_nightly-2.7.0.dev20250703.dist-info/RECORD,,
torch_geometric/__init__.py:
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')

- __version__ = '2.7.0.dev20250702'
+ __version__ = '2.7.0.dev20250703'

  __all__ = [
      'Index',
torch_geometric/datasets/__init__.py:
@@ -81,8 +81,10 @@ from .web_qsp_dataset import WebQSPDataset, CWQDataset
  from .git_mol_dataset import GitMolDataset
  from .molecule_gpt_dataset import MoleculeGPTDataset
  from .instruct_mol_dataset import InstructMolDataset
+ from .protein_mpnn_dataset import ProteinMPNNDataset
  from .tag_dataset import TAGDataset
  from .city import CityNetwork
+ from .teeth3ds import Teeth3DS

  from .dbp15k import DBP15K
  from .aminer import AMiner
@@ -201,8 +203,10 @@ homo_datasets = [
      'GitMolDataset',
      'MoleculeGPTDataset',
      'InstructMolDataset',
+     'ProteinMPNNDataset',
      'TAGDataset',
      'CityNetwork',
+     'Teeth3DS',
  ]

  hetero_datasets = [
torch_geometric/datasets/git_mol_dataset.py:
@@ -102,7 +102,7 @@ class GitMolDataset(InMemoryDataset):

          try:
              from rdkit import Chem, RDLogger
-             RDLogger.DisableLog('rdApp.*')  # type: ignore
+             RDLogger.DisableLog('rdApp.*')
              WITH_RDKIT = True

          except ImportError:
torch_geometric/datasets/protein_mpnn_dataset.py (new file):
@@ -0,0 +1,451 @@
+ import os
+ import pickle
+ import random
+ from collections import defaultdict
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+
+ from torch_geometric.data import (
+     Data,
+     InMemoryDataset,
+     download_url,
+     extract_tar,
+ )
+
+
+ class ProteinMPNNDataset(InMemoryDataset):
+     r"""The ProteinMPNN dataset from the `"Robust deep learning based protein
+     sequence design using ProteinMPNN"
+     <https://www.biorxiv.org/content/10.1101/2022.06.03.494563v1>`_ paper.
+
+     Args:
+         root (str): Root directory where the dataset should be saved.
+         size (str): Size of the PDB information to train the model.
+             If :obj:`"small"`, loads the small dataset (229.4 MB).
+             If :obj:`"large"`, loads the large dataset (64.1 GB).
+             (default: :obj:`"small"`)
+         split (str, optional): If :obj:`"train"`, loads the training dataset.
+             If :obj:`"valid"`, loads the validation dataset.
+             If :obj:`"test"`, loads the test dataset.
+             (default: :obj:`"train"`)
+         datacut (str, optional): Date cutoff to filter the dataset.
+             (default: :obj:`"2030-01-01"`)
+         rescut (float, optional): PDB resolution cutoff.
+             (default: :obj:`3.5`)
+         homo (float, optional): Homology cutoff.
+             (default: :obj:`0.70`)
+         max_length (int, optional): Maximum length of the protein complex.
+             (default: :obj:`10000`)
+         num_units (int, optional): Number of units of the protein complex.
+             (default: :obj:`150`)
+         transform (callable, optional): A function/transform that takes in an
+             :obj:`torch_geometric.data.Data` object and returns a transformed
+             version. The data object will be transformed before every access.
+             (default: :obj:`None`)
+         pre_transform (callable, optional): A function/transform that takes in
+             an :obj:`torch_geometric.data.Data` object and returns a
+             transformed version. The data object will be transformed before
+             being saved to disk. (default: :obj:`None`)
+         pre_filter (callable, optional): A function that takes in an
+             :obj:`torch_geometric.data.Data` object and returns a boolean
+             value, indicating whether the data object should be included in the
+             final dataset. (default: :obj:`None`)
+         force_reload (bool, optional): Whether to re-process the dataset.
+             (default: :obj:`False`)
+     """
+
+     raw_url = {
+         'small':
+         'https://files.ipd.uw.edu/pub/training_sets/'
+         'pdb_2021aug02_sample.tar.gz',
+         'large':
+         'https://files.ipd.uw.edu/pub/training_sets/'
+         'pdb_2021aug02.tar.gz',
+     }
+
+     splits = {
+         'train': 1,
+         'valid': 2,
+         'test': 3,
+     }
+
+     def __init__(
+         self,
+         root: str,
+         size: str = 'small',
+         split: str = 'train',
+         datacut: str = '2030-01-01',
+         rescut: float = 3.5,
+         homo: float = 0.70,
+         max_length: int = 10000,
+         num_units: int = 150,
+         transform: Optional[Callable] = None,
+         pre_transform: Optional[Callable] = None,
+         pre_filter: Optional[Callable] = None,
+         force_reload: bool = False,
+     ) -> None:
+         self.size = size
+         self.split = split
+         self.datacut = datacut
+         self.rescut = rescut
+         self.homo = homo
+         self.max_length = max_length
+         self.num_units = num_units
+
+         self.sub_folder = self.raw_url[self.size].split('/')[-1].split('.')[0]
+
+         super().__init__(root, transform, pre_transform, pre_filter,
+                          force_reload=force_reload)
+         self.load(self.processed_paths[self.splits[self.split]])
+
+     @property
+     def raw_file_names(self) -> List[str]:
+         return [
+             f'{self.sub_folder}/{f}'
+             for f in ['list.csv', 'valid_clusters.txt', 'test_clusters.txt']
+         ]
+
+     @property
+     def processed_file_names(self) -> List[str]:
+         return ['splits.pkl', 'train.pt', 'valid.pt', 'test.pt']
+
+     def download(self) -> None:
+         file_path = download_url(self.raw_url[self.size], self.raw_dir)
+         extract_tar(file_path, self.raw_dir)
+         os.unlink(file_path)
+
+     def process(self) -> None:
+         alphabet_set = set(list('ACDEFGHIKLMNPQRSTVWYX'))
+         cluster_ids = self._process_split()
+         total_items = sum(len(items) for items in cluster_ids.values())
+         data_list = []
+
+         with tqdm(total=total_items, desc="Processing") as pbar:
+             for _, items in cluster_ids.items():
+                 for chain_id, _ in items:
+                     item = self._process_pdb1(chain_id)
+
+                     if 'label' not in item:
+                         pbar.update(1)
+                         continue
+                     if len(list(np.unique(item['idx']))) >= 352:
+                         pbar.update(1)
+                         continue
+
+                     my_dict = self._process_pdb2(item)
+
+                     if len(my_dict['seq']) > self.max_length:
+                         pbar.update(1)
+                         continue
+                     bad_chars = set(list(
+                         my_dict['seq'])).difference(alphabet_set)
+                     if len(bad_chars) > 0:
+                         pbar.update(1)
+                         continue
+
+                     x_chain_all, chain_seq_label_all, mask, chain_mask_all, residue_idx, chain_encoding_all = self._process_pdb3(  # noqa: E501
+                         my_dict)
+
+                     data = Data(
+                         x=x_chain_all,  # [seq_len, 4, 3]
+                         chain_seq_label=chain_seq_label_all,  # [seq_len]
+                         mask=mask,  # [seq_len]
+                         chain_mask_all=chain_mask_all,  # [seq_len]
+                         residue_idx=residue_idx,  # [seq_len]
+                         chain_encoding_all=chain_encoding_all,  # [seq_len]
+                     )
+
+                     if self.pre_filter is not None and not self.pre_filter(
+                             data):
+                         continue
+                     if self.pre_transform is not None:
+                         data = self.pre_transform(data)
+
+                     data_list.append(data)
+
+                     if len(data_list) >= self.num_units:
+                         pbar.update(total_items - pbar.n)
+                         break
+                     pbar.update(1)
+                 else:
+                     continue
+                 break
+         self.save(data_list, self.processed_paths[self.splits[self.split]])
+
+     def _process_split(self) -> Dict[int, List[Tuple[str, int]]]:
+         import pandas as pd
+         save_path = self.processed_paths[0]
+
+         if os.path.exists(save_path):
+             print('Load split')
+             with open(save_path, 'rb') as f:
+                 data = pickle.load(f)
+         else:
+             # CHAINID, DEPOSITION, RESOLUTION, HASH, CLUSTER, SEQUENCE
+             df = pd.read_csv(self.raw_paths[0])
+             df = df[(df['RESOLUTION'] <= self.rescut)
+                     & (df['DEPOSITION'] <= self.datacut)]
+
+             val_ids = pd.read_csv(self.raw_paths[1], header=None)[0].tolist()
+             test_ids = pd.read_csv(self.raw_paths[2], header=None)[0].tolist()
+
+             # compile training and validation sets
+             data = {
+                 'train': defaultdict(list),
+                 'valid': defaultdict(list),
+                 'test': defaultdict(list),
+             }
+
+             for _, r in tqdm(df.iterrows(), desc='Processing split',
+                              total=len(df)):
+                 cluster_id = r['CLUSTER']
+                 hash_id = r['HASH']
+                 chain_id = r['CHAINID']
+                 if cluster_id in val_ids:
+                     data['valid'][cluster_id].append((chain_id, hash_id))
+                 elif cluster_id in test_ids:
+                     data['test'][cluster_id].append((chain_id, hash_id))
+                 else:
+                     data['train'][cluster_id].append((chain_id, hash_id))
+
+             with open(save_path, 'wb') as f:
+                 pickle.dump(data, f)
+
+         return data[self.split]
+
+     def _process_pdb1(self, chain_id: str) -> Dict[str, Any]:
+         pdbid, chid = chain_id.split('_')
+         prefix = f'{self.raw_dir}/{self.sub_folder}/pdb/{pdbid[1:3]}/{pdbid}'
+         # load metadata
+         if not os.path.isfile(f'{prefix}.pt'):
+             return {'seq': np.zeros(5)}
+         meta = torch.load(f'{prefix}.pt')
+         asmb_ids = meta['asmb_ids']
+         asmb_chains = meta['asmb_chains']
+         chids = np.array(meta['chains'])
+
+         # find candidate assemblies which contain chid chain
+         asmb_candidates = {
+             a
+             for a, b in zip(asmb_ids, asmb_chains) if chid in b.split(',')
+         }
+
+         # if the chain is missing from all the assemblies,
+         # then return this chain alone
+         if len(asmb_candidates) < 1:
+             chain = torch.load(f'{prefix}_{chid}.pt')
+             L = len(chain['seq'])
+             return {
+                 'seq': chain['seq'],
+                 'xyz': chain['xyz'],
+                 'idx': torch.zeros(L).int(),
+                 'masked': torch.Tensor([0]).int(),
+                 'label': chain_id,
+             }
+
+         # randomly pick one assembly from candidates
+         asmb_i = random.sample(list(asmb_candidates), 1)
+
+         # indices of selected transforms
+         idx = np.where(np.array(asmb_ids) == asmb_i)[0]
+
+         # load relevant chains
+         chains = {
+             c: torch.load(f'{prefix}_{c}.pt')
+             for i in idx
+             for c in asmb_chains[i] if c in meta['chains']
+         }
+
+         # generate assembly
+         asmb = {}
+         for k in idx:
+
+             # pick k-th xform
+             xform = meta[f'asmb_xform{k}']
+             u = xform[:, :3, :3]
+             r = xform[:, :3, 3]
+
+             # select chains which k-th xform should be applied to
+             s1 = set(meta['chains'])
+             s2 = set(asmb_chains[k].split(','))
+             chains_k = s1 & s2
+
+             # transform selected chains
+             for c in chains_k:
+                 try:
+                     xyz = chains[c]['xyz']
+                     xyz_ru = torch.einsum('bij,raj->brai', u, xyz) + r[:, None,
+                                                                        None, :]
+                     asmb.update({
+                         (c, k, i): xyz_i
+                         for i, xyz_i in enumerate(xyz_ru)
+                     })
+                 except KeyError:
+                     return {'seq': np.zeros(5)}
+
+         # select chains which share considerable similarity to chid
+         seqid = meta['tm'][chids == chid][0, :, 1]
+         homo = {
+             ch_j
+             for seqid_j, ch_j in zip(seqid, chids) if seqid_j > self.homo
+         }
+         # stack all chains in the assembly together
+         seq: str = ''
+         xyz_all: List[torch.Tensor] = []
+         idx_all: List[torch.Tensor] = []
+         masked: List[int] = []
+         seq_list = []
+         for counter, (k, v) in enumerate(asmb.items()):
+             seq += chains[k[0]]['seq']
+             seq_list.append(chains[k[0]]['seq'])
+             xyz_all.append(v)
+             idx_all.append(torch.full((v.shape[0], ), counter))
+             if k[0] in homo:
+                 masked.append(counter)
+
+         return {
+             'seq': seq,
+             'xyz': torch.cat(xyz_all, dim=0),
+             'idx': torch.cat(idx_all, dim=0),
+             'masked': torch.Tensor(masked).int(),
+             'label': chain_id,
+         }
+
+     def _process_pdb2(self, t: Dict[str, Any]) -> Dict[str, Any]:
+         init_alphabet = list(
+             'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz')
+         extra_alphabet = [str(item) for item in list(np.arange(300))]
+         chain_alphabet = init_alphabet + extra_alphabet
+         my_dict: Dict[str, Union[str, int, Dict[str, Any], List[Any]]] = {}
+         concat_seq = ''
+         mask_list = []
+         visible_list = []
+         for idx in list(np.unique(t['idx'])):
+             letter = chain_alphabet[idx]
+             res = np.argwhere(t['idx'] == idx)
+             initial_sequence = "".join(list(
+                 np.array(list(t['seq']))[res][
+                     0,
+                 ]))
+             if initial_sequence[-6:] == "HHHHHH":
+                 res = res[:, :-6]
+             if initial_sequence[0:6] == "HHHHHH":
+                 res = res[:, 6:]
+             if initial_sequence[-7:-1] == "HHHHHH":
+                 res = res[:, :-7]
+             if initial_sequence[-8:-2] == "HHHHHH":
+                 res = res[:, :-8]
+             if initial_sequence[-9:-3] == "HHHHHH":
+                 res = res[:, :-9]
+             if initial_sequence[-10:-4] == "HHHHHH":
+                 res = res[:, :-10]
+             if initial_sequence[1:7] == "HHHHHH":
+                 res = res[:, 7:]
+             if initial_sequence[2:8] == "HHHHHH":
+                 res = res[:, 8:]
+             if initial_sequence[3:9] == "HHHHHH":
+                 res = res[:, 9:]
+             if initial_sequence[4:10] == "HHHHHH":
+                 res = res[:, 10:]
+             if res.shape[1] >= 4:
+                 chain_seq = "".join(list(np.array(list(t['seq']))[res][0]))
+                 my_dict[f'seq_chain_{letter}'] = chain_seq
+                 concat_seq += chain_seq
+                 if idx in t['masked']:
+                     mask_list.append(letter)
+                 else:
+                     visible_list.append(letter)
+                 coords_dict_chain = {}
+                 all_atoms = np.array(t['xyz'][res])[0]  # [L, 14, 3]
+                 for i, c in enumerate(['N', 'CA', 'C', 'O']):
+                     coords_dict_chain[
+                         f'{c}_chain_{letter}'] = all_atoms[:, i, :].tolist()
+                 my_dict[f'coords_chain_{letter}'] = coords_dict_chain
+         my_dict['name'] = t['label']
+         my_dict['masked_list'] = mask_list
+         my_dict['visible_list'] = visible_list
+         my_dict['num_of_chains'] = len(mask_list) + len(visible_list)
+         my_dict['seq'] = concat_seq
+         return my_dict
+
+     def _process_pdb3(
+         self, b: Dict[str, Any]
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
+                torch.Tensor, torch.Tensor]:
+         L = len(b['seq'])
+         # residue idx with jumps across chains
+         residue_idx = -100 * np.ones([L], dtype=np.int32)
+         # get the list of masked / visible chains
+         masked_chains, visible_chains = b['masked_list'], b['visible_list']
+         visible_temp_dict, masked_temp_dict = {}, {}
+         for letter in masked_chains + visible_chains:
+             chain_seq = b[f'seq_chain_{letter}']
+             if letter in visible_chains:
+                 visible_temp_dict[letter] = chain_seq
+             elif letter in masked_chains:
+                 masked_temp_dict[letter] = chain_seq
+         # check for duplicate chains (same sequence but different identity)
+         for _, vm in masked_temp_dict.items():
+             for kv, vv in visible_temp_dict.items():
+                 if vm == vv:
+                     if kv not in masked_chains:
+                         masked_chains.append(kv)
+                     if kv in visible_chains:
+                         visible_chains.remove(kv)
+         # build protein data structures
+         all_chains = masked_chains + visible_chains
+         np.random.shuffle(all_chains)
+         x_chain_list = []
+         chain_mask_list = []
+         chain_seq_list = []
+         chain_encoding_list = []
+         c, l0, l1 = 1, 0, 0
+         for letter in all_chains:
+             chain_seq = b[f'seq_chain_{letter}']
+             chain_length = len(chain_seq)
+             chain_coords = b[f'coords_chain_{letter}']
+             x_chain = np.stack([
+                 chain_coords[c] for c in [
+                     f'N_chain_{letter}', f'CA_chain_{letter}',
+                     f'C_chain_{letter}', f'O_chain_{letter}'
+                 ]
+             ], 1)  # [chain_length, 4, 3]
+             x_chain_list.append(x_chain)
+             chain_seq_list.append(chain_seq)
+             if letter in visible_chains:
+                 chain_mask = np.zeros(chain_length)  # 0 for visible chains
+             elif letter in masked_chains:
+                 chain_mask = np.ones(chain_length)  # 1 for masked chains
+             chain_mask_list.append(chain_mask)
+             chain_encoding_list.append(c * np.ones(chain_length))
+             l1 += chain_length
+             residue_idx[l0:l1] = 100 * (c - 1) + np.arange(l0, l1)
+             l0 += chain_length
+             c += 1
+         x_chain_all = np.concatenate(x_chain_list, 0)  # [L, 4, 3]
+         chain_seq_all = "".join(chain_seq_list)
+         # [L,] 1.0 for places that need to be predicted
+         chain_mask_all = np.concatenate(chain_mask_list, 0)
+         chain_encoding_all = np.concatenate(chain_encoding_list, 0)
+
+         # Convert to labels
+         alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
+         chain_seq_label_all = np.asarray(
+             [alphabet.index(a) for a in chain_seq_all], dtype=np.int32)
+
+         isnan = np.isnan(x_chain_all)
+         mask = np.isfinite(np.sum(x_chain_all, (1, 2))).astype(np.float32)
+         x_chain_all[isnan] = 0.
+
+         # Conversion
+         return (
+             torch.from_numpy(x_chain_all).to(dtype=torch.float32),
+             torch.from_numpy(chain_seq_label_all).to(dtype=torch.long),
+             torch.from_numpy(mask).to(dtype=torch.float32),
+             torch.from_numpy(chain_mask_all).to(dtype=torch.float32),
+             torch.from_numpy(residue_idx).to(dtype=torch.long),
+             torch.from_numpy(chain_encoding_all).to(dtype=torch.long),
+         )
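For orientation, a minimal usage sketch of the new dataset, not part of the diff: the root path is hypothetical, and the first access triggers the 229.4 MB 'small' download from the `raw_url` shown above.

from torch_geometric.datasets import ProteinMPNNDataset

# 'data/ProteinMPNN' is a made-up root directory for this sketch.
dataset = ProteinMPNNDataset(root='data/ProteinMPNN', size='small',
                             split='train')
data = dataset[0]
print(data.x.shape)  # Backbone atom coordinates: [seq_len, 4, 3].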
torch_geometric/datasets/qm9.py:
@@ -202,7 +202,7 @@ class QM9(InMemoryDataset):
              from rdkit import Chem, RDLogger
              from rdkit.Chem.rdchem import BondType as BT
              from rdkit.Chem.rdchem import HybridizationType
-             RDLogger.DisableLog('rdApp.*')  # type: ignore
+             RDLogger.DisableLog('rdApp.*')
              WITH_RDKIT = True

          except ImportError:
torch_geometric/datasets/teeth3ds.py (new file):
@@ -0,0 +1,269 @@
+ import json
+ import os
+ import os.path as osp
+ from glob import glob
+ from typing import Callable, Dict, List, Optional
+
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+
+ from torch_geometric.data import (
+     Data,
+     InMemoryDataset,
+     download_url,
+     extract_zip,
+ )
+
+
+ class Teeth3DS(InMemoryDataset):
+     r"""The Teeth3DS+ dataset from the `"An Extended Benchmark for Intra-oral
+     3D Scans Analysis" <https://crns-smartvision.github.io/teeth3ds/>`_ paper.
+
+     This dataset is the first comprehensive public benchmark designed to
+     advance the field of intra-oral 3D scan analysis, developed as part of the
+     3DTeethSeg 2022 and 3DTeethLand 2024 MICCAI challenges, aiming to drive
+     research in teeth identification, segmentation, labeling, 3D modeling,
+     and dental landmark identification.
+     The dataset includes at least 1,800 intra-oral scans (containing 23,999
+     annotated teeth) collected from 900 patients, covering both upper and lower
+     jaws separately.
+
+     Args:
+         root (str): Root directory where the dataset should be saved.
+         split (str): The split name (one of :obj:`"Teeth3DS"`,
+             :obj:`"3DTeethSeg22_challenge"` or :obj:`"3DTeethLand_challenge"`).
+         train (bool, optional): If :obj:`True`, loads the training dataset,
+             otherwise the test dataset. (default: :obj:`True`)
+         num_samples (int, optional): Number of points to sample from each mesh.
+             (default: :obj:`30000`)
+         transform (callable, optional): A function/transform that takes in an
+             :obj:`torch_geometric.data.Data` object and returns a transformed
+             version. The data object will be transformed before every access.
+             (default: :obj:`None`)
+         pre_transform (callable, optional): A function/transform that takes in
+             an :obj:`torch_geometric.data.Data` object and returns a
+             transformed version. The data object will be transformed before
+             being saved to disk. (default: :obj:`None`)
+         force_reload (bool, optional): Whether to re-process the dataset.
+             (default: :obj:`False`)
+     """
+     urls = {
+         'data_part_1.zip':
+         'https://osf.io/download/qhprs/',
+         'data_part_2.zip':
+         'https://osf.io/download/4pwnr/',
+         'data_part_3.zip':
+         'https://osf.io/download/frwdp/',
+         'data_part_4.zip':
+         'https://osf.io/download/2arn4/',
+         'data_part_5.zip':
+         'https://osf.io/download/xrz5f/',
+         'data_part_6.zip':
+         'https://osf.io/download/23hgq/',
+         'data_part_7.zip':
+         'https://osf.io/download/u83ad/',
+         'train_test_split':
+         'https://files.de-1.osf.io/v1/'
+         'resources/xctdy/providers/osfstorage/?zip='
+     }
+
+     sample_url = {
+         'teeth3ds_sample': 'https://osf.io/download/vr38s/',
+     }
+
+     landmarks_urls = {
+         '3DTeethLand_landmarks_train.zip': 'https://osf.io/download/k5hbj/',
+         '3DTeethLand_landmarks_test.zip': 'https://osf.io/download/sqw5e/',
+     }
+
+     def __init__(
+         self,
+         root: str,
+         split:
+         str = 'Teeth3DS',  # [3DTeethSeg22_challenge, 3DTeethLand_challenge]
+         train: bool = True,
+         num_samples: int = 30000,
+         transform: Optional[Callable] = None,
+         pre_transform: Optional[Callable] = None,
+         force_reload: bool = False,
+     ) -> None:
+
+         self.mode = 'training' if train else 'testing'
+         self.split = split
+         self.num_samples = num_samples
+
+         super().__init__(root, transform, pre_transform,
+                          force_reload=force_reload)
+
+     @property
+     def processed_dir(self) -> str:
+         return os.path.join(self.root, f'processed_{self.split}_{self.mode}')
+
+     @property
+     def raw_file_names(self) -> List[str]:
+         return ['license.txt']
+
+     @property
+     def processed_file_names(self) -> List[str]:
+         # Directory containing train/test split files:
+         split_subdir = 'teeth3ds_sample' if self.split == 'sample' else ''
+         split_dir = osp.join(
+             self.raw_dir,
+             split_subdir,
+             f'{self.split}_train_test_split',
+         )
+
+         split_files = glob(osp.join(split_dir, f'{self.mode}*.txt'))
+
+         # Collect all file names from the split files:
+         combined_list = []
+         for file_path in split_files:
+             with open(file_path) as file:
+                 combined_list.extend(file.read().splitlines())
+
+         # Generate the list of processed file paths:
+         return [f'{file_name}.pt' for file_name in combined_list]
+
+     def download(self) -> None:
+         if self.split == 'sample':
+             for key, url in self.sample_url.items():
+                 path = download_url(url, self.root, filename=key)
+                 extract_zip(path, self.raw_dir)
+                 os.unlink(path)
+         else:
+             for key, url in self.urls.items():
+                 path = download_url(url, self.root, filename=key)
+                 extract_zip(path, self.raw_dir)
+                 os.unlink(path)
+             for key, url in self.landmarks_urls.items():
+                 path = download_url(url, self.root, filename=key)
+                 extract_zip(path, self.raw_dir)  # Extract each downloaded part
+                 os.unlink(path)
+
+     def process_file(self, file_path: str) -> Optional[Data]:
+         """Processes the input file path to load mesh data, annotations,
+         and prepare the input features for a graph-based deep learning model.
+         """
+         import trimesh
+         from fpsample import bucket_fps_kdline_sampling
+
+         mesh = trimesh.load_mesh(file_path)
+
+         if isinstance(mesh, list):
+             # Handle the case where a list of Geometry objects is returned
+             mesh = mesh[0]
+
+         vertices = mesh.vertices
+         vertex_normals = mesh.vertex_normals
+
+         # Perform sampling on mesh vertices:
+         if len(vertices) < self.num_samples:
+             sampled_indices = np.random.choice(
+                 len(vertices),
+                 self.num_samples,
+                 replace=True,
+             )
+         else:
+             sampled_indices = bucket_fps_kdline_sampling(
+                 vertices,
+                 self.num_samples,
+                 h=5,
+                 start_idx=0,
+             )
+
+         if len(sampled_indices) != self.num_samples:
+             raise RuntimeError(f"Sampled points mismatch, expected "
+                                f"{self.num_samples} points, but got "
+                                f"{len(sampled_indices)} for '{file_path}'")
+
+         # Extract features and annotations for the sampled points:
+         pos = torch.tensor(vertices[sampled_indices], dtype=torch.float)
+         x = torch.tensor(vertex_normals[sampled_indices], dtype=torch.float)
+
+         # Load segmentation annotations:
+         seg_annotation_path = file_path.replace('.obj', '.json')
+         if osp.exists(seg_annotation_path):
+             with open(seg_annotation_path) as f:
+                 seg_annotations = json.load(f)
+             y = torch.tensor(
+                 np.asarray(seg_annotations['labels'])[sampled_indices],
+                 dtype=torch.float)
+             instances = torch.tensor(
+                 np.asarray(seg_annotations['instances'])[sampled_indices],
+                 dtype=torch.float)
+         else:
+             y = torch.empty(0, 3)
+             instances = torch.empty(0, 3)
+
+         # Load landmarks annotations:
+         landmarks_annotation_path = file_path.replace('.obj', '__kpt.json')
+
+         # Parse keypoint annotations into structured tensors:
+         keypoints_dict: Dict[str, List] = {
+             key: []
+             for key in [
+                 'Mesial', 'Distal', 'Cusp', 'InnerPoint', 'OuterPoint',
+                 'FacialPoint'
+             ]
+         }
+         keypoint_tensors: Dict[str, torch.Tensor] = {
+             key: torch.empty(0, 3)
+             for key in [
+                 'Mesial', 'Distal', 'Cusp', 'InnerPoint', 'OuterPoint',
+                 'FacialPoint'
+             ]
+         }
+         if osp.exists(landmarks_annotation_path):
+             with open(landmarks_annotation_path) as f:
+                 landmarks_annotations = json.load(f)
+
+             for keypoint in landmarks_annotations['objects']:
+                 keypoints_dict[keypoint['class']].extend(keypoint['coord'])
+
+             keypoint_tensors = {
+                 k: torch.tensor(np.asarray(v),
+                                 dtype=torch.float).reshape(-1, 3)
+                 for k, v in keypoints_dict.items()
+             }
+
+         data = Data(
+             pos=pos,
+             x=x,
+             y=y,
+             instances=instances,
+             jaw=file_path.split('.obj')[0].split('_')[1],
+             mesial=keypoint_tensors['Mesial'],
+             distal=keypoint_tensors['Distal'],
+             cusp=keypoint_tensors['Cusp'],
+             inner_point=keypoint_tensors['InnerPoint'],
+             outer_point=keypoint_tensors['OuterPoint'],
+             facial_point=keypoint_tensors['FacialPoint'],
+         )
+
+         if self.pre_transform is not None:
+             data = self.pre_transform(data)
+
+         return data
+
+     def process(self) -> None:
+         for file in tqdm(self.processed_file_names):
+             name = file.split('.')[0]
+             path = osp.join(self.raw_dir, '**', '*', name + '.obj')
+             paths = glob(path)
+             if len(paths) == 1:
+                 data = self.process_file(paths[0])
+                 torch.save(data, osp.join(self.processed_dir, file))
+
+     def len(self) -> int:
+         return len(self.processed_file_names)
+
+     def get(self, idx: int) -> Data:
+         return torch.load(
+             osp.join(self.processed_dir, self.processed_file_names[idx]),
+             weights_only=False,
+         )
+
+     def __repr__(self) -> str:
+         return (f'{self.__class__.__name__}({len(self)}, '
+                 f'mode={self.mode}, split={self.split})')
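A minimal usage sketch for the Teeth3DS dataset, not part of the diff: the root path is hypothetical, and processing additionally requires the third-party `trimesh` and `fpsample` packages imported inside `process_file` above.

from torch_geometric.datasets import Teeth3DS

# 'data/Teeth3DS' is a made-up root directory for this sketch.
dataset = Teeth3DS(root='data/Teeth3DS', split='Teeth3DS', train=True)
data = dataset[0]
print(data.pos.shape, data.x.shape)  # Sampled positions and vertex normals.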
torch_geometric/nn/conv/meshcnn_conv.py:
@@ -1,7 +1,7 @@
  # The below is to suppress the warning on torch.nn.conv.MeshCNNConv::update
  # pyright: reportIncompatibleMethodOverride=false
+ import warnings
  from typing import Optional
- from warnings import warn

  import torch
  from torch.nn import Linear, Module, ModuleList
@@ -456,13 +456,10 @@ class MeshCNNConv(MessagePassing):
                  {type(network)}"
              if not hasattr(network, "in_channels") and \
                      not hasattr(network, "in_features"):
-                 warn(
-                     f"kernel[{i}] does not have attribute \
-                     'in_channels' nor 'out_features'. The \
-                     network must take as input a \
-                     {self.in_channels}-dimensional tensor. \
-                     Still, assuming user configured \
-                     correctly. Continuing..", stacklevel=2)
+                 warnings.warn(
+                     f"kernel[{i}] does not have attribute 'in_channels' nor "
+                     f"'out_features'. The network must take as input a "
+                     f"{self.in_channels}-dimensional tensor.", stacklevel=2)
              else:
                  input_dimension = getattr(network, "in_channels",
                                            network.in_features)
@@ -475,13 +472,10 @@ class MeshCNNConv(MessagePassing):

              if not hasattr(network, "out_channels") and \
                      not hasattr(network, "out_features"):
-                 warn(
-                     f"kernel[{i}] does not have attribute \
-                     'in_channels' nor 'out_features'. The \
-                     network must take as input a \
-                     {self.in_channels}-dimensional tensor. \
-                     Still, assuming user configured \
-                     correctly. Continuing..", stacklevel=2)
+                 warnings.warn(
+                     f"kernel[{i}] does not have attribute 'in_channels' nor "
+                     f"'out_features'. The network must take as input a "
+                     f"{self.in_channels}-dimensional tensor.", stacklevel=2)
              else:
                  output_dimension = getattr(network, "out_channels",
                                             network.out_features)
torch_geometric/nn/models/__init__.py:
@@ -32,6 +32,7 @@ from .visnet import ViSNet
  from .g_retriever import GRetriever
  from .git_mol import GITMol
  from .molecule_gpt import MoleculeGPT
+ from .protein_mpnn import ProteinMPNN
  from .glem import GLEM
  from .sgformer import SGFormer
  # Deprecated:
@@ -86,6 +87,7 @@ __all__ = classes = [
      'GRetriever',
      'GITMol',
      'MoleculeGPT',
+     'ProteinMPNN',
      'GLEM',
      'SGFormer',
      'ARLinkPredictor',
torch_geometric/nn/models/glem.py:
@@ -8,6 +8,13 @@ from torch_geometric.loader import DataLoader, NeighborLoader
  from torch_geometric.nn.models import GraphSAGE, basic_gnn


+ def deal_nan(x):
+     if isinstance(x, torch.Tensor):
+         x = x.clone()
+         x[torch.isnan(x)] = 0.0
+     return x
+
+
  class GLEM(torch.nn.Module):
      r"""This GNN+LM co-training model is based on GLEM from the `"Learning on
      Large-scale Text-attributed Graphs via Variational Inference"
@@ -379,9 +386,6 @@ class GLEM(torch.nn.Module):
              is_augmented: use EM or just train GNN and LM with gold data

          """
-         def deal_nan(x):
-             return 0 if torch.isnan(x) else x
-
          if is_augmented and (sum(~is_gold) > 0):
              mle_loss = deal_nan(loss_func(logits[is_gold], labels[is_gold]))
              # all other labels beside from ground truth(gold labels)
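Net effect of the two glem.py hunks: deal_nan moves from a local closure to a module-level helper, and its semantics change from the scalar-only `0 if torch.isnan(x) else x` (which raises "Boolean value of Tensor with more than one element is ambiguous" for multi-element tensors) to zeroing NaN entries elementwise on a clone. A small behavior sketch with made-up values:

import torch

loss = torch.tensor([0.5, float('nan'), 1.0])
safe = loss.clone()
safe[torch.isnan(safe)] = 0.0  # Mirrors the new module-level deal_nan().
print(safe)  # tensor([0.5000, 0.0000, 1.0000])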
torch_geometric/nn/models/protein_mpnn.py (new file):
@@ -0,0 +1,304 @@
+ from itertools import product
+ from typing import Tuple
+
+ import torch
+ import torch.nn.functional as F
+
+ from torch_geometric.nn import knn_graph
+ from torch_geometric.nn.conv import MessagePassing
+ from torch_geometric.utils import to_dense_adj, to_dense_batch
+
+
+ class PositionWiseFeedForward(torch.nn.Module):
+     def __init__(self, in_channels: int, hidden_channels: int) -> None:
+         super().__init__()
+         self.out = torch.nn.Sequential(
+             torch.nn.Linear(in_channels, hidden_channels),
+             torch.nn.GELU(),
+             torch.nn.Linear(hidden_channels, in_channels),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.out(x)
+
+
+ class PositionalEncoding(torch.nn.Module):
+     def __init__(self, hidden_channels: int,
+                  max_relative_feature: int = 32) -> None:
+         super().__init__()
+         self.max_relative_feature = max_relative_feature
+         self.emb = torch.nn.Embedding(2 * max_relative_feature + 2,
+                                       hidden_channels)
+
+     def forward(self, offset, mask) -> torch.Tensor:
+         d = torch.clip(offset + self.max_relative_feature, 0,
+                        2 * self.max_relative_feature) * mask + (1 - mask) * (
+                            2 * self.max_relative_feature + 1)  # noqa: E501
+         return self.emb(d.long())
+
+
+ class Encoder(MessagePassing):
+     def __init__(
+         self,
+         in_channels: int,
+         hidden_channels: int,
+         dropout: float = 0.1,
+         scale: float = 30,
+     ) -> None:
+         super().__init__()
+         self.out_v = torch.nn.Sequential(
+             torch.nn.Linear(in_channels, hidden_channels),
+             torch.nn.GELU(),
+             torch.nn.Linear(hidden_channels, hidden_channels),
+             torch.nn.GELU(),
+             torch.nn.Linear(hidden_channels, hidden_channels),
+         )
+         self.out_e = torch.nn.Sequential(
+             torch.nn.Linear(in_channels, hidden_channels),
+             torch.nn.GELU(),
+             torch.nn.Linear(hidden_channels, hidden_channels),
+             torch.nn.GELU(),
+             torch.nn.Linear(hidden_channels, hidden_channels),
+         )
+         self.dropout1 = torch.nn.Dropout(dropout)
+         self.dropout2 = torch.nn.Dropout(dropout)
+         self.dropout3 = torch.nn.Dropout(dropout)
+         self.norm1 = torch.nn.LayerNorm(hidden_channels)
+         self.norm2 = torch.nn.LayerNorm(hidden_channels)
+         self.norm3 = torch.nn.LayerNorm(hidden_channels)
+         self.scale = scale
+         self.dense = PositionWiseFeedForward(hidden_channels,
+                                              hidden_channels * 4)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         edge_index: torch.Tensor,
+         edge_attr: torch.Tensor,
+     ) -> torch.Tensor:
+         # x: [N, d_v]
+         # edge_index: [2, E]
+         # edge_attr: [E, d_e]
+         # update node features
+         h_message = self.propagate(x=x, edge_index=edge_index,
+                                    edge_attr=edge_attr)
+         dh = h_message / self.scale
+         x = self.norm1(x + self.dropout1(dh))
+         dh = self.dense(x)
+         x = self.norm2(x + self.dropout2(dh))
+         # update edge features
+         row, col = edge_index
+         x_i, x_j = x[row], x[col]
+         h_e = torch.cat([x_i, x_j, edge_attr], dim=-1)
+         h_e = self.out_e(h_e)
+         edge_attr = self.norm3(edge_attr + self.dropout3(h_e))
+         return x, edge_attr
+
+     def message(self, x_i: torch.Tensor, x_j: torch.Tensor,
+                 edge_attr: torch.Tensor) -> torch.Tensor:
+         h = torch.cat([x_i, x_j, edge_attr], dim=-1)  # [E, 2*d_v + d_e]
+         h = self.out_e(h)  # [E, d_e]
+         return h
+
+
+ class Decoder(MessagePassing):
+     def __init__(
+         self,
+         in_channels: int,
+         hidden_channels: int,
+         dropout: float = 0.1,
+         scale: float = 30,
+     ) -> None:
+         super().__init__()
+         self.out_v = torch.nn.Sequential(
+             torch.nn.Linear(in_channels, hidden_channels),
+             torch.nn.GELU(),
+             torch.nn.Linear(hidden_channels, hidden_channels),
+             torch.nn.GELU(),
+             torch.nn.Linear(hidden_channels, hidden_channels),
+         )
+         self.dropout1 = torch.nn.Dropout(dropout)
+         self.dropout2 = torch.nn.Dropout(dropout)
+         self.norm1 = torch.nn.LayerNorm(hidden_channels)
+         self.norm2 = torch.nn.LayerNorm(hidden_channels)
+         self.scale = scale
+         self.dense = PositionWiseFeedForward(hidden_channels,
+                                              hidden_channels * 4)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         edge_index: torch.Tensor,
+         edge_attr: torch.Tensor,
+         x_label: torch.Tensor,
+         mask: torch.Tensor,
+     ) -> torch.Tensor:
+         # x: [N, d_v]
+         # edge_index: [2, E]
+         # edge_attr: [E, d_e]
+         h_message = self.propagate(x=x, x_label=x_label, edge_index=edge_index,
+                                    edge_attr=edge_attr, mask=mask)
+         dh = h_message / self.scale
+         x = self.norm1(x + self.dropout1(dh))
+         dh = self.dense(x)
+         x = self.norm2(x + self.dropout2(dh))
+         return x
+
+     def message(self, x_i: torch.Tensor, x_j: torch.Tensor,
+                 x_label_j: torch.Tensor, edge_attr: torch.Tensor,
+                 mask: torch.Tensor) -> torch.Tensor:
+         h_1 = torch.cat([x_j, edge_attr, x_label_j], dim=-1)
+         h_0 = torch.cat([x_j, edge_attr, torch.zeros_like(x_label_j)], dim=-1)
+         h = h_1 * mask + h_0 * (1 - mask)
+         h = torch.concat([x_i, h], dim=-1)
+         h = self.out_v(h)
+         return h
+
+
+ class ProteinMPNN(torch.nn.Module):
+     def __init__(
+         self,
+         hidden_dim: int = 128,
+         num_encoder_layers: int = 3,
+         num_decoder_layers: int = 3,
+         num_neighbors: int = 30,
+         num_rbf: int = 16,
+         dropout: float = 0.1,
+         augment_eps: float = 0.2,
+         num_positional_embedding: int = 16,
+         vocab_size: int = 21,
+     ) -> None:
+         super().__init__()
+         self.augment_eps = augment_eps
+         self.hidden_dim = hidden_dim
+         self.num_neighbors = num_neighbors
+         self.num_rbf = num_rbf
+         self.embedding = PositionalEncoding(num_positional_embedding)
+         self.edge_mlp = torch.nn.Sequential(
+             torch.nn.Linear(num_positional_embedding + 400, hidden_dim),
+             torch.nn.LayerNorm(hidden_dim),
+             torch.nn.Linear(hidden_dim, hidden_dim),
+         )
+         self.label_embedding = torch.nn.Embedding(vocab_size, hidden_dim)
+         self.encoder_layers = torch.nn.ModuleList([
+             Encoder(hidden_dim * 3, hidden_dim, dropout)
+             for _ in range(num_encoder_layers)
+         ])
+
+         self.decoder_layers = torch.nn.ModuleList([
+             Decoder(hidden_dim * 4, hidden_dim, dropout)
+             for _ in range(num_decoder_layers)
+         ])
+         self.output = torch.nn.Linear(hidden_dim, vocab_size)
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         for p in self.parameters():
+             if p.dim() > 1:
+                 torch.nn.init.xavier_uniform_(p)
+
+     def _featurize(
+         self,
+         x: torch.Tensor,
+         mask: torch.Tensor,
+         batch: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         N, Ca, C, O = (x[:, i, :] for i in range(4))  # noqa: E741
+         b = Ca - N
+         c = C - Ca
+         a = torch.cross(b, c, dim=-1)
+         Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca
+
+         valid_mask = mask.bool()
+         valid_Ca = Ca[valid_mask]
+         valid_batch = batch[valid_mask]
+
+         edge_index = knn_graph(valid_Ca, k=self.num_neighbors,
+                                batch=valid_batch, loop=True)
+
+         row, col = edge_index
+         original_indices = torch.arange(Ca.size(0),
+                                         device=x.device)[valid_mask]
+         edge_index_original = torch.stack(
+             [original_indices[row], original_indices[col]], dim=0)
+         row, col = edge_index_original
+
+         rbf_all = []
+         for A, B in list(product([N, Ca, C, O, Cb], repeat=2)):
+             distances = torch.sqrt(torch.sum((A[row] - B[col])**2, 1) + 1e-6)
+             rbf = self._rbf(distances)
+             rbf_all.append(rbf)
+
+         return edge_index_original, torch.cat(rbf_all, dim=-1)
+
+     def _rbf(self, D: torch.Tensor) -> torch.Tensor:
+         D_min, D_max, D_count = 2., 22., self.num_rbf
+         D_mu = torch.linspace(D_min, D_max, D_count, device=D.device)
+         D_mu = D_mu.view([1, -1])
+         D_sigma = (D_max - D_min) / D_count
+         D_expand = torch.unsqueeze(D, -1)
+         RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
+         return RBF
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         chain_seq_label: torch.Tensor,
+         mask: torch.Tensor,
+         chain_mask_all: torch.Tensor,
+         residue_idx: torch.Tensor,
+         chain_encoding_all: torch.Tensor,
+         batch: torch.Tensor,
+     ) -> torch.Tensor:
+         device = x.device
+         if self.training and self.augment_eps > 0:
+             x = x + self.augment_eps * torch.randn_like(x)
+
+         edge_index, edge_attr = self._featurize(x, mask, batch)
+
+         row, col = edge_index
+         offset = residue_idx[row] - residue_idx[col]
+         # find self vs non-self interaction
+         e_chains = ((chain_encoding_all[row] -
+                      chain_encoding_all[col]) == 0).long()
+         e_pos = self.embedding(offset, e_chains)
+         h_e = self.edge_mlp(torch.cat([edge_attr, e_pos], dim=-1))
+         h_v = torch.zeros(x.size(0), self.hidden_dim, device=x.device)
+
+         # encoder
+         for encoder in self.encoder_layers:
+             h_v, h_e = encoder(h_v, edge_index, h_e)
+
+         # mask
+         h_label = self.label_embedding(chain_seq_label)
+         batch_chain_mask_all, _ = to_dense_batch(chain_mask_all * mask,
+                                                  batch)  # [B, N]
+         # 0 - visible - encoder, 1 - masked - decoder
+         decoding_order = torch.argsort(
+             (batch_chain_mask_all + 1e-4) * (torch.abs(
+                 torch.randn(batch_chain_mask_all.shape, device=device))))
+         mask_size = batch_chain_mask_all.size(1)
+         permutation_matrix_reverse = F.one_hot(decoding_order,
+                                                num_classes=mask_size).float()
+         order_mask_backward = torch.einsum(
+             'ij, biq, bjp->bqp',
+             1 - torch.triu(torch.ones(mask_size, mask_size, device=device)),
+             permutation_matrix_reverse,
+             permutation_matrix_reverse,
+         )
+         adj = to_dense_adj(edge_index, batch)
+         mask_attend = order_mask_backward[adj.bool()].unsqueeze(-1)
+
+         # decoder
+         for decoder in self.decoder_layers:
+             h_v = decoder(
+                 h_v,
+                 edge_index,
+                 h_e,
+                 h_label,
+                 mask_attend,
+             )
+
+         logits = self.output(h_v)
+         return F.log_softmax(logits, dim=-1)
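A shape-level sketch of the new model's forward pass with random stand-in tensors, not part of the diff; all values are hypothetical, and a real run would use the fields produced by `ProteinMPNNDataset` plus a `batch` vector from a DataLoader.

import torch
from torch_geometric.nn.models import ProteinMPNN

model = ProteinMPNN()
n = 8  # Number of residues in this toy example.
out = model(
    x=torch.randn(n, 4, 3),                       # Backbone coordinates.
    chain_seq_label=torch.randint(0, 21, (n, )),  # Residue labels.
    mask=torch.ones(n),                           # Valid-residue mask.
    chain_mask_all=torch.ones(n),                 # 1 = positions to predict.
    residue_idx=torch.arange(n),                  # Residue indices.
    chain_encoding_all=torch.ones(n).long(),      # Chain membership.
    batch=torch.zeros(n).long(),                  # Single graph in the batch.
)
print(out.shape)  # Per-residue log-probabilities: [n, vocab_size].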
torch_geometric/utils/convert.py:
@@ -452,15 +452,22 @@ def to_cugraph(
      g = cugraph.Graph(directed=directed)
      df = cudf.from_dlpack(to_dlpack(edge_index.t()))

+     df = cudf.DataFrame({
+         'source':
+         cudf.from_dlpack(to_dlpack(edge_index[0])),
+         'destination':
+         cudf.from_dlpack(to_dlpack(edge_index[1])),
+     })
+
      if edge_weight is not None:
          assert edge_weight.dim() == 1
-         df['2'] = cudf.from_dlpack(to_dlpack(edge_weight))
+         df['weight'] = cudf.from_dlpack(to_dlpack(edge_weight))

      g.from_cudf_edgelist(
          df,
-         source=0,
-         destination=1,
-         edge_attr='2' if edge_weight is not None else None,
+         source='source',
+         destination='destination',
+         edge_attr='weight' if edge_weight is not None else None,
          renumber=relabel_nodes,
      )

@@ -476,13 +483,13 @@ def from_cugraph(g: Any) -> Tuple[Tensor, Optional[Tensor]]:
      """
      df = g.view_edge_list()

-     src = from_dlpack(df[0].to_dlpack()).long()
-     dst = from_dlpack(df[1].to_dlpack()).long()
+     src = from_dlpack(df[g.source_columns].to_dlpack()).long()
+     dst = from_dlpack(df[g.destination_columns].to_dlpack()).long()
      edge_index = torch.stack([src, dst], dim=0)

      edge_weight = None
-     if '2' in df:
-         edge_weight = from_dlpack(df['2'].to_dlpack())
+     if g.weight_column is not None:
+         edge_weight = from_dlpack(df[g.weight_column].to_dlpack())

      return edge_index, edge_weight

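Both convert.py hunks move the cuGraph interop from positional column names ('0', '1', '2') to explicit 'source'/'destination'/'weight' columns, with from_cugraph() resolving the names via the graph's source_columns, destination_columns, and weight_column attributes. A round-trip sketch, not part of the diff (requires a CUDA environment with cudf/cugraph installed; the tensors are placeholders):

import torch
from torch_geometric.utils import from_cugraph, to_cugraph

edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]], device='cuda')
edge_weight = torch.rand(3, device='cuda')

g = to_cugraph(edge_index, edge_weight)      # Stored under named columns.
edge_index2, edge_weight2 = from_cugraph(g)  # Names resolved from `g`.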
torch_geometric/utils/smiles.py:
@@ -148,7 +148,7 @@ def from_smiles(
      """
      from rdkit import Chem, RDLogger

-     RDLogger.DisableLog('rdApp.*')  # type: ignore
+     RDLogger.DisableLog('rdApp.*')

      mol = Chem.MolFromSmiles(smiles)