pyg-nightly 2.7.0.dev20250702__py3-none-any.whl → 2.7.0.dev20250704__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- {pyg_nightly-2.7.0.dev20250702.dist-info → pyg_nightly-2.7.0.dev20250704.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250702.dist-info → pyg_nightly-2.7.0.dev20250704.dist-info}/RECORD +20 -15
- torch_geometric/__init__.py +1 -1
- torch_geometric/datasets/__init__.py +4 -0
- torch_geometric/datasets/git_mol_dataset.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +1 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/nn/attention/__init__.py +2 -0
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/conv/meshcnn_conv.py +9 -15
- torch_geometric/nn/models/__init__.py +4 -0
- torch_geometric/nn/models/glem.py +7 -3
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/protein_mpnn.py +304 -0
- torch_geometric/utils/convert.py +15 -8
- torch_geometric/utils/smiles.py +1 -1
- {pyg_nightly-2.7.0.dev20250702.dist-info → pyg_nightly-2.7.0.dev20250704.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250702.dist-info → pyg_nightly-2.7.0.dev20250704.dist-info}/licenses/LICENSE +0 -0
torch_geometric/datasets/teeth3ds.py (new file)
@@ -0,0 +1,269 @@
+import json
+import os
+import os.path as osp
+from glob import glob
+from typing import Callable, Dict, List, Optional
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from torch_geometric.data import (
+    Data,
+    InMemoryDataset,
+    download_url,
+    extract_zip,
+)
+
+
+class Teeth3DS(InMemoryDataset):
+    r"""The Teeth3DS+ dataset from the `"An Extended Benchmark for Intra-oral
+    3D Scans Analysis" <https://crns-smartvision.github.io/teeth3ds/>`_ paper.
+
+    This dataset is the first comprehensive public benchmark designed to
+    advance the field of intra-oral 3D scan analysis developed as part of the
+    3DTeethSeg 2022 and 3DTeethLand 2024 MICCAI challenges, aiming to drive
+    research in teeth identification, segmentation, labeling, 3D modeling,
+    and dental landmark identification.
+    The dataset includes at least 1,800 intra-oral scans (containing 23,999
+    annotated teeth) collected from 900 patients, covering both upper and lower
+    jaws separately.
+
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        split (str): The split name (one of :obj:`"Teeth3DS"`,
+            :obj:`"3DTeethSeg22_challenge"` or :obj:`"3DTeethLand_challenge"`).
+        train (bool, optional): If :obj:`True`, loads the training dataset,
+            otherwise the test dataset. (default: :obj:`True`)
+        num_samples (int, optional): Number of points to sample from each mesh.
+            (default: :obj:`30000`)
+        transform (callable, optional): A function/transform that takes in an
+            :obj:`torch_geometric.data.Data` object and returns a transformed
+            version. The data object will be transformed before every access.
+            (default: :obj:`None`)
+        pre_transform (callable, optional): A function/transform that takes in
+            an :obj:`torch_geometric.data.Data` object and returns a
+            transformed version. The data object will be transformed before
+            being saved to disk. (default: :obj:`None`)
+        force_reload (bool, optional): Whether to re-process the dataset.
+            (default: :obj:`False`)
+    """
+    urls = {
+        'data_part_1.zip':
+        'https://osf.io/download/qhprs/',
+        'data_part_2.zip':
+        'https://osf.io/download/4pwnr/',
+        'data_part_3.zip':
+        'https://osf.io/download/frwdp/',
+        'data_part_4.zip':
+        'https://osf.io/download/2arn4/',
+        'data_part_5.zip':
+        'https://osf.io/download/xrz5f/',
+        'data_part_6.zip':
+        'https://osf.io/download/23hgq/',
+        'data_part_7.zip':
+        'https://osf.io/download/u83ad/',
+        'train_test_split':
+        'https://files.de-1.osf.io/v1/'
+        'resources/xctdy/providers/osfstorage/?zip='
+    }
+
+    sample_url = {
+        'teeth3ds_sample': 'https://osf.io/download/vr38s/',
+    }
+
+    landmarks_urls = {
+        '3DTeethLand_landmarks_train.zip': 'https://osf.io/download/k5hbj/',
+        '3DTeethLand_landmarks_test.zip': 'https://osf.io/download/sqw5e/',
+    }
+
+    def __init__(
+        self,
+        root: str,
+        split:
+        str = 'Teeth3DS',  # [3DTeethSeg22_challenge, 3DTeethLand_challenge]
+        train: bool = True,
+        num_samples: int = 30000,
+        transform: Optional[Callable] = None,
+        pre_transform: Optional[Callable] = None,
+        force_reload: bool = False,
+    ) -> None:
+
+        self.mode = 'training' if train else 'testing'
+        self.split = split
+        self.num_samples = num_samples
+
+        super().__init__(root, transform, pre_transform,
+                         force_reload=force_reload)
+
+    @property
+    def processed_dir(self) -> str:
+        return os.path.join(self.root, f'processed_{self.split}_{self.mode}')
+
+    @property
+    def raw_file_names(self) -> List[str]:
+        return ['license.txt']
+
+    @property
+    def processed_file_names(self) -> List[str]:
+        # Directory containing train/test split files:
+        split_subdir = 'teeth3ds_sample' if self.split == 'sample' else ''
+        split_dir = osp.join(
+            self.raw_dir,
+            split_subdir,
+            f'{self.split}_train_test_split',
+        )
+
+        split_files = glob(osp.join(split_dir, f'{self.mode}*.txt'))
+
+        # Collect all file names from the split files:
+        combined_list = []
+        for file_path in split_files:
+            with open(file_path) as file:
+                combined_list.extend(file.read().splitlines())
+
+        # Generate the list of processed file paths:
+        return [f'{file_name}.pt' for file_name in combined_list]
+
+    def download(self) -> None:
+        if self.split == 'sample':
+            for key, url in self.sample_url.items():
+                path = download_url(url, self.root, filename=key)
+                extract_zip(path, self.raw_dir)
+                os.unlink(path)
+        else:
+            for key, url in self.urls.items():
+                path = download_url(url, self.root, filename=key)
+                extract_zip(path, self.raw_dir)
+                os.unlink(path)
+            for key, url in self.landmarks_urls.items():
+                path = download_url(url, self.root, filename=key)
+                extract_zip(path, self.raw_dir)  # Extract each downloaded part
+                os.unlink(path)
+
+    def process_file(self, file_path: str) -> Optional[Data]:
+        """Processes the input file path to load mesh data, annotations,
+        and prepare the input features for a graph-based deep learning model.
+        """
+        import trimesh
+        from fpsample import bucket_fps_kdline_sampling
+
+        mesh = trimesh.load_mesh(file_path)
+
+        if isinstance(mesh, list):
+            # Handle the case where a list of Geometry objects is returned:
+            mesh = mesh[0]
+
+        vertices = mesh.vertices
+        vertex_normals = mesh.vertex_normals
+
+        # Perform sampling on mesh vertices:
+        if len(vertices) < self.num_samples:
+            sampled_indices = np.random.choice(
+                len(vertices),
+                self.num_samples,
+                replace=True,
+            )
+        else:
+            sampled_indices = bucket_fps_kdline_sampling(
+                vertices,
+                self.num_samples,
+                h=5,
+                start_idx=0,
+            )
+
+        if len(sampled_indices) != self.num_samples:
+            raise RuntimeError(f"Sampled points mismatch, expected "
+                               f"{self.num_samples} points, but got "
+                               f"{len(sampled_indices)} for '{file_path}'")
+
+        # Extract features and annotations for the sampled points:
+        pos = torch.tensor(vertices[sampled_indices], dtype=torch.float)
+        x = torch.tensor(vertex_normals[sampled_indices], dtype=torch.float)
+
+        # Load segmentation annotations:
+        seg_annotation_path = file_path.replace('.obj', '.json')
+        if osp.exists(seg_annotation_path):
+            with open(seg_annotation_path) as f:
+                seg_annotations = json.load(f)
+            y = torch.tensor(
+                np.asarray(seg_annotations['labels'])[sampled_indices],
+                dtype=torch.float)
+            instances = torch.tensor(
+                np.asarray(seg_annotations['instances'])[sampled_indices],
+                dtype=torch.float)
+        else:
+            y = torch.empty(0, 3)
+            instances = torch.empty(0, 3)
+
+        # Load landmarks annotations:
+        landmarks_annotation_path = file_path.replace('.obj', '__kpt.json')
+
+        # Parse keypoint annotations into structured tensors:
+        keypoints_dict: Dict[str, List] = {
+            key: []
+            for key in [
+                'Mesial', 'Distal', 'Cusp', 'InnerPoint', 'OuterPoint',
+                'FacialPoint'
+            ]
+        }
+        keypoint_tensors: Dict[str, torch.Tensor] = {
+            key: torch.empty(0, 3)
+            for key in [
+                'Mesial', 'Distal', 'Cusp', 'InnerPoint', 'OuterPoint',
+                'FacialPoint'
+            ]
+        }
+        if osp.exists(landmarks_annotation_path):
+            with open(landmarks_annotation_path) as f:
+                landmarks_annotations = json.load(f)
+
+            for keypoint in landmarks_annotations['objects']:
+                keypoints_dict[keypoint['class']].extend(keypoint['coord'])
+
+            keypoint_tensors = {
+                k: torch.tensor(np.asarray(v),
+                                dtype=torch.float).reshape(-1, 3)
+                for k, v in keypoints_dict.items()
+            }
+
+        data = Data(
+            pos=pos,
+            x=x,
+            y=y,
+            instances=instances,
+            jaw=file_path.split('.obj')[0].split('_')[1],
+            mesial=keypoint_tensors['Mesial'],
+            distal=keypoint_tensors['Distal'],
+            cusp=keypoint_tensors['Cusp'],
+            inner_point=keypoint_tensors['InnerPoint'],
+            outer_point=keypoint_tensors['OuterPoint'],
+            facial_point=keypoint_tensors['FacialPoint'],
+        )
+
+        if self.pre_transform is not None:
+            data = self.pre_transform(data)
+
+        return data
+
+    def process(self) -> None:
+        for file in tqdm(self.processed_file_names):
+            name = file.split('.')[0]
+            path = osp.join(self.raw_dir, '**', '*', name + '.obj')
+            paths = glob(path)
+            if len(paths) == 1:
+                data = self.process_file(paths[0])
+                torch.save(data, osp.join(self.processed_dir, file))
+
+    def len(self) -> int:
+        return len(self.processed_file_names)
+
+    def get(self, idx: int) -> Data:
+        return torch.load(
+            osp.join(self.processed_dir, self.processed_file_names[idx]),
+            weights_only=False,
+        )
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}({len(self)}, '
+                f'mode={self.mode}, split={self.split})')
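
For orientation, a minimal usage sketch of the new dataset follows. It assumes the class is exported from torch_geometric.datasets (the four-line datasets/__init__.py change above suggests this), that the optional trimesh and fpsample packages are installed (process_file() imports them lazily), and that the root path is a placeholder; the first instantiation downloads the OSF archives listed in urls.

from torch_geometric.datasets import Teeth3DS

# Sketch only: '/data/Teeth3DS' is a placeholder path, and the first call
# downloads and processes several zip archives from OSF.
dataset = Teeth3DS(root='/data/Teeth3DS', split='Teeth3DS', train=True)

data = dataset[0]  # A single intra-oral scan.
data.pos           # Sampled vertex positions of shape [30000, 3].
data.x             # Vertex normals used as input features, [30000, 3].
data.jaw           # Jaw identifier parsed from the scan's file name.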
torch_geometric/nn/attention/__init__.py
@@ -1,9 +1,11 @@
 from .performer import PerformerAttention
 from .qformer import QFormer
 from .sgformer import SGFormerAttention
+from .polynormer import PolynormerAttention

 __all__ = [
     'PerformerAttention',
     'QFormer',
     'SGFormerAttention',
+    'PolynormerAttention',
 ]
torch_geometric/nn/attention/polynormer.py (new file)
@@ -0,0 +1,107 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+
+class PolynormerAttention(torch.nn.Module):
+    r"""The polynomial-expressive attention mechanism from the
+    `"Polynormer: Polynomial-Expressive Graph Transformer in Linear Time"
+    <https://arxiv.org/abs/2403.01232>`_ paper.
+
+    Args:
+        channels (int): Size of each input sample.
+        heads (int): Number of parallel attention heads.
+        head_channels (int, optional): Size of each attention head.
+            (default: :obj:`64`)
+        beta (float, optional): Polynormer beta initialization.
+            (default: :obj:`0.9`)
+        qkv_bias (bool, optional): If specified, adds bias to the query, key
+            and value projections in the self-attention.
+            (default: :obj:`False`)
+        qk_shared (bool, optional): Whether the weights of query and key are
+            shared. (default: :obj:`True`)
+        dropout (float, optional): Dropout probability of the final
+            attention output. (default: :obj:`0.0`)
+    """
+    def __init__(
+        self,
+        channels: int,
+        heads: int,
+        head_channels: int = 64,
+        beta: float = 0.9,
+        qkv_bias: bool = False,
+        qk_shared: bool = True,
+        dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        self.head_channels = head_channels
+        self.heads = heads
+        self.beta = beta
+        self.qk_shared = qk_shared
+
+        inner_channels = heads * head_channels
+        self.h_lins = torch.nn.Linear(channels, inner_channels)
+        if not self.qk_shared:
+            self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.lns = torch.nn.LayerNorm(inner_channels)
+        self.lin_out = torch.nn.Linear(inner_channels, inner_channels)
+        self.dropout = torch.nn.Dropout(dropout)
+
+    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): Node feature tensor
+                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
+                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
+                each graph, and feature dimension :math:`F`.
+            mask (torch.Tensor, optional): Mask matrix
+                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
+                the valid nodes for each graph. (default: :obj:`None`)
+        """
+        B, N, *_ = x.shape
+        h = self.h_lins(x)
+        k = self.k(x).sigmoid().view(B, N, self.head_channels, self.heads)
+        if self.qk_shared:
+            q = k
+        else:
+            q = F.sigmoid(self.q(x)).view(B, N, self.head_channels, self.heads)
+        v = self.v(x).view(B, N, self.head_channels, self.heads)
+
+        if mask is not None:
+            mask = mask[:, :, None, None]
+            v.masked_fill_(~mask, 0.)
+
+        # numerator
+        kv = torch.einsum('bndh, bnmh -> bdmh', k, v)
+        num = torch.einsum('bndh, bdmh -> bnmh', q, kv)
+
+        # denominator
+        k_sum = torch.einsum('bndh -> bdh', k)
+        den = torch.einsum('bndh, bdh -> bnh', q, k_sum).unsqueeze(2)
+
+        # linear global attention based on kernel trick
+        x = (num / (den + 1e-6)).reshape(B, N, -1)
+        x = self.lns(x) * (h + self.beta)
+        x = F.relu(self.lin_out(x))
+        x = self.dropout(x)
+
+        return x
+
+    def reset_parameters(self) -> None:
+        self.h_lins.reset_parameters()
+        if not self.qk_shared:
+            self.q.reset_parameters()
+        self.k.reset_parameters()
+        self.v.reset_parameters()
+        self.lns.reset_parameters()
+        self.lin_out.reset_parameters()
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}('
+                f'heads={self.heads}, '
+                f'head_channels={self.head_channels})')
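
As a quick shape check, here is a minimal sketch of running the new attention layer on a dense mini-batch (all sizes are illustrative). With channels=64, heads=4, and head_channels=16, the output keeps the [batch_size, num_nodes, heads * head_channels] layout:

import torch

from torch_geometric.nn.attention import PolynormerAttention

attn = PolynormerAttention(channels=64, heads=4, head_channels=16)

x = torch.randn(2, 10, 64)                  # [batch_size, num_nodes, channels]
mask = torch.ones(2, 10, dtype=torch.bool)  # All nodes are valid.

out = attn(x, mask)
print(out.shape)  # torch.Size([2, 10, 64]), i.e. heads * head_channels.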
torch_geometric/nn/conv/meshcnn_conv.py
@@ -1,7 +1,7 @@
 # The below is to suppress the warning on torch.nn.conv.MeshCNNConv::update
 # pyright: reportIncompatibleMethodOverride=false
+import warnings
 from typing import Optional
-from warnings import warn

 import torch
 from torch.nn import Linear, Module, ModuleList
@@ -456,13 +456,10 @@ class MeshCNNConv(MessagePassing):
             {type(network)}"
         if not hasattr(network, "in_channels") and \
                 not hasattr(network, "in_features"):
-            warn(
-                f"kernel[{i}] does not have attribute
-
-
-                {self.in_channels}-dimensional tensor. \
-                Still, assuming user configured \
-                correctly. Continuing..", stacklevel=2)
+            warnings.warn(
+                f"kernel[{i}] does not have attribute 'in_channels' nor "
+                f"'out_features'. The network must take as input a "
+                f"{self.in_channels}-dimensional tensor.", stacklevel=2)
         else:
             input_dimension = getattr(network, "in_channels",
                                       network.in_features)
@@ -475,13 +472,10 @@ class MeshCNNConv(MessagePassing):

         if not hasattr(network, "out_channels") and \
                 not hasattr(network, "out_features"):
-            warn(
-                f"kernel[{i}] does not have attribute
-
-
-                {self.in_channels}-dimensional tensor. \
-                Still, assuming user configured \
-                correctly. Continuing..", stacklevel=2)
+            warnings.warn(
+                f"kernel[{i}] does not have attribute 'in_channels' nor "
+                f"'out_features'. The network must take as input a "
+                f"{self.in_channels}-dimensional tensor.", stacklevel=2)
         else:
             output_dimension = getattr(network, "out_channels",
                                        network.out_features)
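
For reference, the surviving else branches infer dimensions by duck typing. A standalone sketch of that lookup (not the module's exact code path):

import torch

# `in_channels` (PyG-style layers) is preferred; `in_features`
# (e.g. torch.nn.Linear) serves as the fallback.
network = torch.nn.Linear(8, 16)
input_dimension = getattr(network, "in_channels", network.in_features)
assert input_dimension == 8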
torch_geometric/nn/models/__init__.py
@@ -32,8 +32,10 @@ from .visnet import ViSNet
 from .g_retriever import GRetriever
 from .git_mol import GITMol
 from .molecule_gpt import MoleculeGPT
+from .protein_mpnn import ProteinMPNN
 from .glem import GLEM
 from .sgformer import SGFormer
+from .polynormer import Polynormer
 # Deprecated:
 from torch_geometric.explain.algorithm.captum import (to_captum_input,
                                                       captum_output_to_dicts)
@@ -86,7 +88,9 @@ __all__ = classes = [
     'GRetriever',
     'GITMol',
     'MoleculeGPT',
+    'ProteinMPNN',
     'GLEM',
     'SGFormer',
+    'Polynormer',
     'ARLinkPredictor',
 ]
torch_geometric/nn/models/glem.py
@@ -8,6 +8,13 @@ from torch_geometric.loader import DataLoader, NeighborLoader
 from torch_geometric.nn.models import GraphSAGE, basic_gnn


+def deal_nan(x):
+    if isinstance(x, torch.Tensor):
+        x = x.clone()
+        x[torch.isnan(x)] = 0.0
+    return x
+
+
 class GLEM(torch.nn.Module):
     r"""This GNN+LM co-training model is based on GLEM from the `"Learning on
     Large-scale Text-attributed Graphs via Variational Inference"
@@ -379,9 +386,6 @@ class GLEM(torch.nn.Module):
         is_augmented: use EM or just train GNN and LM with gold data

         """
-        def deal_nan(x):
-            return 0 if torch.isnan(x) else x
-
         if is_augmented and (sum(~is_gold) > 0):
             mle_loss = deal_nan(loss_func(logits[is_gold], labels[is_gold]))
             # all other labels beside from ground truth(gold labels)
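
Beyond moving the helper to module level, the refactor also changes semantics: the old closure only worked for scalar losses (a multi-element torch.isnan(x) is ambiguous in a boolean context and raises), while the new version zeroes NaN entries elementwise on a clone and passes non-tensors through unchanged. A small sketch of the new behavior:

import torch

from torch_geometric.nn.models.glem import deal_nan

loss = torch.tensor([0.5, float('nan'), 1.5])
print(deal_nan(loss))    # tensor([0.5000, 0.0000, 1.5000])

scalar = torch.tensor(float('nan'))
print(deal_nan(scalar))  # tensor(0.)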
torch_geometric/nn/models/polynormer.py (new file)
@@ -0,0 +1,206 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from torch_geometric.nn import GATConv, GCNConv
+from torch_geometric.nn.attention import PolynormerAttention
+from torch_geometric.utils import to_dense_batch
+
+
+class Polynormer(torch.nn.Module):
+    r"""The Polynormer model from the
+    `"Polynormer: Polynomial-Expressive Graph Transformer in Linear Time"
+    <https://arxiv.org/abs/2403.01232>`_ paper.
+
+    Args:
+        in_channels (int): Input channels.
+        hidden_channels (int): Hidden channels.
+        out_channels (int): Output channels.
+        local_layers (int): The number of local attention layers.
+            (default: :obj:`7`)
+        global_layers (int): The number of global attention layers.
+            (default: :obj:`2`)
+        in_dropout (float): Input dropout rate.
+            (default: :obj:`0.15`)
+        dropout (float): Dropout rate.
+            (default: :obj:`0.5`)
+        global_dropout (float): Global dropout rate.
+            (default: :obj:`0.5`)
+        heads (int): The number of heads.
+            (default: :obj:`1`)
+        beta (float): Polynormer beta initialization.
+            (default: :obj:`0.9`)
+        qk_shared (bool, optional): Whether the weights of query and key are
+            shared. (default: :obj:`False`)
+        pre_ln (bool): Pre layer normalization.
+            (default: :obj:`False`)
+        post_bn (bool): Post batch normalization.
+            (default: :obj:`True`)
+        local_attn (bool): Whether to use local attention.
+            (default: :obj:`False`)
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        out_channels: int,
+        local_layers: int = 7,
+        global_layers: int = 2,
+        in_dropout: float = 0.15,
+        dropout: float = 0.5,
+        global_dropout: float = 0.5,
+        heads: int = 1,
+        beta: float = 0.9,
+        qk_shared: bool = False,
+        pre_ln: bool = False,
+        post_bn: bool = True,
+        local_attn: bool = False,
+    ) -> None:
+        super().__init__()
+        self._global = False
+        self.in_drop = in_dropout
+        self.dropout = dropout
+        self.pre_ln = pre_ln
+        self.post_bn = post_bn
+
+        self.beta = beta
+
+        self.h_lins = torch.nn.ModuleList()
+        self.local_convs = torch.nn.ModuleList()
+        self.lins = torch.nn.ModuleList()
+        self.lns = torch.nn.ModuleList()
+        if self.pre_ln:
+            self.pre_lns = torch.nn.ModuleList()
+        if self.post_bn:
+            self.post_bns = torch.nn.ModuleList()
+
+        # first layer
+        inner_channels = heads * hidden_channels
+        self.h_lins.append(torch.nn.Linear(in_channels, inner_channels))
+        if local_attn:
+            self.local_convs.append(
+                GATConv(in_channels, hidden_channels, heads=heads, concat=True,
+                        add_self_loops=False, bias=False))
+        else:
+            self.local_convs.append(
+                GCNConv(in_channels, inner_channels, cached=False,
+                        normalize=True))
+
+        self.lins.append(torch.nn.Linear(in_channels, inner_channels))
+        self.lns.append(torch.nn.LayerNorm(inner_channels))
+        if self.pre_ln:
+            self.pre_lns.append(torch.nn.LayerNorm(in_channels))
+        if self.post_bn:
+            self.post_bns.append(torch.nn.BatchNorm1d(inner_channels))
+
+        # following layers
+        for _ in range(local_layers - 1):
+            self.h_lins.append(torch.nn.Linear(inner_channels, inner_channels))
+            if local_attn:
+                self.local_convs.append(
+                    GATConv(inner_channels, hidden_channels, heads=heads,
+                            concat=True, add_self_loops=False, bias=False))
+            else:
+                self.local_convs.append(
+                    GCNConv(inner_channels, inner_channels, cached=False,
+                            normalize=True))
+
+            self.lins.append(torch.nn.Linear(inner_channels, inner_channels))
+            self.lns.append(torch.nn.LayerNorm(inner_channels))
+            if self.pre_ln:
+                self.pre_lns.append(torch.nn.LayerNorm(heads *
+                                                       hidden_channels))
+            if self.post_bn:
+                self.post_bns.append(torch.nn.BatchNorm1d(inner_channels))
+
+        self.lin_in = torch.nn.Linear(in_channels, inner_channels)
+        self.ln = torch.nn.LayerNorm(inner_channels)
+
+        self.global_attn = torch.nn.ModuleList()
+        for _ in range(global_layers):
+            self.global_attn.append(
+                PolynormerAttention(
+                    channels=hidden_channels,
+                    heads=heads,
+                    head_channels=hidden_channels,
+                    beta=beta,
+                    dropout=global_dropout,
+                    qk_shared=qk_shared,
+                ))
+        self.pred_local = torch.nn.Linear(inner_channels, out_channels)
+        self.pred_global = torch.nn.Linear(inner_channels, out_channels)
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        for local_conv in self.local_convs:
+            local_conv.reset_parameters()
+        for attn in self.global_attn:
+            attn.reset_parameters()
+        for lin in self.lins:
+            lin.reset_parameters()
+        for h_lin in self.h_lins:
+            h_lin.reset_parameters()
+        for ln in self.lns:
+            ln.reset_parameters()
+        if self.pre_ln:
+            for p_ln in self.pre_lns:
+                p_ln.reset_parameters()
+        if self.post_bn:
+            for p_bn in self.post_bns:
+                p_bn.reset_parameters()
+        self.lin_in.reset_parameters()
+        self.ln.reset_parameters()
+        self.pred_local.reset_parameters()
+        self.pred_global.reset_parameters()
+
+    def forward(
+        self,
+        x: Tensor,
+        edge_index: Tensor,
+        batch: Optional[Tensor],
+    ) -> Tensor:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): The input node features.
+            edge_index (torch.Tensor or SparseTensor): The edge indices.
+            batch (torch.Tensor, optional): The batch vector
+                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
+                each element to a specific example.
+        """
+        x = F.dropout(x, p=self.in_drop, training=self.training)
+
+        # equivariant local attention
+        x_local = 0
+        for i, local_conv in enumerate(self.local_convs):
+            if self.pre_ln:
+                x = self.pre_lns[i](x)
+            h = self.h_lins[i](x)
+            h = F.relu(h)
+            x = local_conv(x, edge_index) + self.lins[i](x)
+            if self.post_bn:
+                x = self.post_bns[i](x)
+            x = F.relu(x)
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = (1 - self.beta) * self.lns[i](h * x) + self.beta * x
+            x_local = x_local + x
+
+        # equivariant global attention
+        if self._global:
+            batch, indices = batch.sort()
+            rev_perm = torch.empty_like(indices)
+            rev_perm[indices] = torch.arange(len(indices),
+                                             device=indices.device)
+            x_local = self.ln(x_local[indices])
+            x_global, mask = to_dense_batch(x_local, batch)
+            for attn in self.global_attn:
+                x_global = attn(x_global, mask)
+            x = x_global[mask][rev_perm]
+            x = self.pred_global(x)
+        else:
+            x = self.pred_local(x_local)
+
+        return F.log_softmax(x, dim=-1)
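
Note that _global is initialized to False, so a fresh model runs only the local GCN/GAT branch and predicts via pred_local; the Polynormer training recipe enables the linear global attention stage partway through training. A minimal sketch follows (the flag flip mirrors how the private attribute is initialized above; the exact schedule is left to the training script, and all sizes are illustrative):

import torch

from torch_geometric.nn.models import Polynormer

model = Polynormer(in_channels=16, hidden_channels=64, out_channels=7)

x = torch.randn(100, 16)                      # 100 nodes, 16 features.
edge_index = torch.randint(0, 100, (2, 500))  # Random edges, illustrative only.
batch = torch.zeros(100, dtype=torch.long)    # A single graph.

out_local = model(x, edge_index, batch)       # Local stage: pred_local head.
model._global = True                          # Switch on global attention.
out_global = model(x, edge_index, batch)      # Local + global: pred_global head.
print(out_global.shape)                       # torch.Size([100, 7]) log-probs.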