pyg-nightly 2.7.0.dev20250701__py3-none-any.whl → 2.7.0.dev20250703__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -0,0 +1,269 @@
+ import json
+ import os
+ import os.path as osp
+ from glob import glob
+ from typing import Callable, Dict, List, Optional
+
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+
+ from torch_geometric.data import (
+     Data,
+     InMemoryDataset,
+     download_url,
+     extract_zip,
+ )
+
+
+ class Teeth3DS(InMemoryDataset):
+     r"""The Teeth3DS+ dataset from the `"An Extended Benchmark for Intra-oral
+     3D Scans Analysis" <https://crns-smartvision.github.io/teeth3ds/>`_ paper.
+
+     This dataset is the first comprehensive public benchmark designed to
+     advance the field of intra-oral 3D scan analysis developed as part of the
+     3DTeethSeg 2022 and 3DTeethLand 2024 MICCAI challenges, aiming to drive
+     research in teeth identification, segmentation, labeling, 3D modeling,
+     and dental landmark identification.
+     The dataset includes at least 1,800 intra-oral scans (containing 23,999
+     annotated teeth) collected from 900 patients, covering both upper and lower
+     jaws separately.
+
+     Args:
+         root (str): Root directory where the dataset should be saved.
+         split (str): The split name (one of :obj:`"Teeth3DS"`,
+             :obj:`"3DTeethSeg22_challenge"` or :obj:`"3DTeethLand_challenge"`).
+         train (bool, optional): If :obj:`True`, loads the training dataset,
+             otherwise the test dataset. (default: :obj:`True`)
+         num_samples (int, optional): Number of points to sample from each mesh.
+             (default: :obj:`30000`)
+         transform (callable, optional): A function/transform that takes in an
+             :obj:`torch_geometric.data.Data` object and returns a transformed
+             version. The data object will be transformed before every access.
+             (default: :obj:`None`)
+         pre_transform (callable, optional): A function/transform that takes in
+             an :obj:`torch_geometric.data.Data` object and returns a
+             transformed version. The data object will be transformed before
+             being saved to disk. (default: :obj:`None`)
+         force_reload (bool, optional): Whether to re-process the dataset.
+             (default: :obj:`False`)
+     """
+     urls = {
+         'data_part_1.zip':
+         'https://osf.io/download/qhprs/',
+         'data_part_2.zip':
+         'https://osf.io/download/4pwnr/',
+         'data_part_3.zip':
+         'https://osf.io/download/frwdp/',
+         'data_part_4.zip':
+         'https://osf.io/download/2arn4/',
+         'data_part_5.zip':
+         'https://osf.io/download/xrz5f/',
+         'data_part_6.zip':
+         'https://osf.io/download/23hgq/',
+         'data_part_7.zip':
+         'https://osf.io/download/u83ad/',
+         'train_test_split':
+         'https://files.de-1.osf.io/v1/'
+         'resources/xctdy/providers/osfstorage/?zip='
+     }
+
+     sample_url = {
+         'teeth3ds_sample': 'https://osf.io/download/vr38s/',
+     }
+
+     landmarks_urls = {
+         '3DTeethLand_landmarks_train.zip': 'https://osf.io/download/k5hbj/',
+         '3DTeethLand_landmarks_test.zip': 'https://osf.io/download/sqw5e/',
+     }
+
+     def __init__(
+         self,
+         root: str,
+         split:
+         str = 'Teeth3DS',  # [3DTeethSeg22_challenge, 3DTeethLand_challenge]
+         train: bool = True,
+         num_samples: int = 30000,
+         transform: Optional[Callable] = None,
+         pre_transform: Optional[Callable] = None,
+         force_reload: bool = False,
+     ) -> None:
+
+         self.mode = 'training' if train else 'testing'
+         self.split = split
+         self.num_samples = num_samples
+
+         super().__init__(root, transform, pre_transform,
+                          force_reload=force_reload)
+
+     @property
+     def processed_dir(self) -> str:
+         return os.path.join(self.root, f'processed_{self.split}_{self.mode}')
+
+     @property
+     def raw_file_names(self) -> List[str]:
+         return ['license.txt']
+
+     @property
+     def processed_file_names(self) -> List[str]:
+         # Directory containing train/test split files:
+         split_subdir = 'teeth3ds_sample' if self.split == 'sample' else ''
+         split_dir = osp.join(
+             self.raw_dir,
+             split_subdir,
+             f'{self.split}_train_test_split',
+         )
+
+         split_files = glob(osp.join(split_dir, f'{self.mode}*.txt'))
+
+         # Collect all file names from the split files:
+         combined_list = []
+         for file_path in split_files:
+             with open(file_path) as file:
+                 combined_list.extend(file.read().splitlines())
+
+         # Generate the list of processed file paths:
+         return [f'{file_name}.pt' for file_name in combined_list]
+
+     def download(self) -> None:
+         if self.split == 'sample':
+             for key, url in self.sample_url.items():
+                 path = download_url(url, self.root, filename=key)
+                 extract_zip(path, self.raw_dir)
+                 os.unlink(path)
+         else:
+             for key, url in self.urls.items():
+                 path = download_url(url, self.root, filename=key)
+                 extract_zip(path, self.raw_dir)
+                 os.unlink(path)
+             for key, url in self.landmarks_urls.items():
+                 path = download_url(url, self.root, filename=key)
+                 extract_zip(path, self.raw_dir)  # Extract each downloaded part
+                 os.unlink(path)
+
+     def process_file(self, file_path: str) -> Optional[Data]:
+         """Processes the input file path to load mesh data, annotations,
+         and prepare the input features for a graph-based deep learning model.
+         """
+         import trimesh
+         from fpsample import bucket_fps_kdline_sampling
+
+         mesh = trimesh.load_mesh(file_path)
+
+         if isinstance(mesh, list):
+             # Handle the case where a list of Geometry objects is returned
+             mesh = mesh[0]
+
+         vertices = mesh.vertices
+         vertex_normals = mesh.vertex_normals
+
+         # Perform sampling on mesh vertices:
+         if len(vertices) < self.num_samples:
+             sampled_indices = np.random.choice(
+                 len(vertices),
+                 self.num_samples,
+                 replace=True,
+             )
+         else:
+             sampled_indices = bucket_fps_kdline_sampling(
+                 vertices,
+                 self.num_samples,
+                 h=5,
+                 start_idx=0,
+             )
+
+         if len(sampled_indices) != self.num_samples:
+             raise RuntimeError(f"Sampled points mismatch, expected "
+                                f"{self.num_samples} points, but got "
+                                f"{len(sampled_indices)} for '{file_path}'")
+
+         # Extract features and annotations for the sampled points:
+         pos = torch.tensor(vertices[sampled_indices], dtype=torch.float)
+         x = torch.tensor(vertex_normals[sampled_indices], dtype=torch.float)
+
+         # Load segmentation annotations:
+         seg_annotation_path = file_path.replace('.obj', '.json')
+         if osp.exists(seg_annotation_path):
+             with open(seg_annotation_path) as f:
+                 seg_annotations = json.load(f)
+             y = torch.tensor(
+                 np.asarray(seg_annotations['labels'])[sampled_indices],
+                 dtype=torch.float)
+             instances = torch.tensor(
+                 np.asarray(seg_annotations['instances'])[sampled_indices],
+                 dtype=torch.float)
+         else:
+             y = torch.empty(0, 3)
+             instances = torch.empty(0, 3)
+
+         # Load landmarks annotations:
+         landmarks_annotation_path = file_path.replace('.obj', '__kpt.json')
+
+         # Parse keypoint annotations into structured tensors:
+         keypoints_dict: Dict[str, List] = {
+             key: []
+             for key in [
+                 'Mesial', 'Distal', 'Cusp', 'InnerPoint', 'OuterPoint',
+                 'FacialPoint'
+             ]
+         }
+         keypoint_tensors: Dict[str, torch.Tensor] = {
+             key: torch.empty(0, 3)
+             for key in [
+                 'Mesial', 'Distal', 'Cusp', 'InnerPoint', 'OuterPoint',
+                 'FacialPoint'
+             ]
+         }
+         if osp.exists(landmarks_annotation_path):
+             with open(landmarks_annotation_path) as f:
+                 landmarks_annotations = json.load(f)
+
+             for keypoint in landmarks_annotations['objects']:
+                 keypoints_dict[keypoint['class']].extend(keypoint['coord'])
+
+             keypoint_tensors = {
+                 k: torch.tensor(np.asarray(v),
+                                 dtype=torch.float).reshape(-1, 3)
+                 for k, v in keypoints_dict.items()
+             }
+
+         data = Data(
+             pos=pos,
+             x=x,
+             y=y,
+             instances=instances,
+             jaw=file_path.split('.obj')[0].split('_')[1],
+             mesial=keypoint_tensors['Mesial'],
+             distal=keypoint_tensors['Distal'],
+             cusp=keypoint_tensors['Cusp'],
+             inner_point=keypoint_tensors['InnerPoint'],
+             outer_point=keypoint_tensors['OuterPoint'],
+             facial_point=keypoint_tensors['FacialPoint'],
+         )
+
+         if self.pre_transform is not None:
+             data = self.pre_transform(data)
+
+         return data
+
+     def process(self) -> None:
+         for file in tqdm(self.processed_file_names):
+             name = file.split('.')[0]
+             path = osp.join(self.raw_dir, '**', '*', name + '.obj')
+             paths = glob(path)
+             if len(paths) == 1:
+                 data = self.process_file(paths[0])
+                 torch.save(data, osp.join(self.processed_dir, file))
+
+     def len(self) -> int:
+         return len(self.processed_file_names)
+
+     def get(self, idx: int) -> Data:
+         return torch.load(
+             osp.join(self.processed_dir, self.processed_file_names[idx]),
+             weights_only=False,
+         )
+
+     def __repr__(self) -> str:
+         return (f'{self.__class__.__name__}({len(self)}, '
+                 f'mode={self.mode}, split={self.split})')
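
The new Teeth3DS dataset above follows the usual PyG dataset interface. A minimal usage sketch, assuming the class is also registered under torch_geometric.datasets in this nightly and that the optional trimesh and fpsample dependencies are installed:

    from torch_geometric.datasets import Teeth3DS  # assumed export location

    # First use downloads the OSF archives listed in `urls` and samples
    # 30,000 points (positions plus vertex normals) from every scan.
    train_dataset = Teeth3DS(root='data/Teeth3DS', split='Teeth3DS', train=True)
    test_dataset = Teeth3DS(root='data/Teeth3DS', split='Teeth3DS', train=False)

    data = train_dataset[0]  # e.g. Data(pos=[30000, 3], x=[30000, 3], ...)

The download() method also accepts split='sample', which fetches only the much smaller teeth3ds_sample archive.
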
@@ -32,6 +32,8 @@ class PatchTransformerAggregation(Aggregation):
          aggr (str or list[str], optional): The aggregation module, *e.g.*,
              :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
              :obj:`"var"`, :obj:`"std"`. (default: :obj:`"mean"`)
+         device (torch.device, optional): The device of the module.
+             (default: :obj:`None`)
      """
      def __init__(
          self,
@@ -43,6 +45,7 @@ class PatchTransformerAggregation(Aggregation):
          heads: int = 1,
          dropout: float = 0.0,
          aggr: Union[str, List[str]] = 'mean',
+         device: Optional[torch.device] = None,
      ) -> None:
          super().__init__()

@@ -55,12 +58,13 @@ class PatchTransformerAggregation(Aggregation):
          for aggr in self.aggrs:
              assert aggr in ['sum', 'mean', 'min', 'max', 'var', 'std']

-         self.lin = torch.nn.Linear(in_channels, hidden_channels)
+         self.lin = torch.nn.Linear(in_channels, hidden_channels, device=device)
          self.pad_projector = torch.nn.Linear(
              patch_size * hidden_channels,
              hidden_channels,
+             device=device,
          )
-         self.pe = PositionalEncoding(hidden_channels)
+         self.pe = PositionalEncoding(hidden_channels, device=device)

          self.blocks = torch.nn.ModuleList([
              MultiheadAttentionBlock(
@@ -68,12 +72,14 @@ class PatchTransformerAggregation(Aggregation):
                  heads=heads,
                  layer_norm=True,
                  dropout=dropout,
+                 device=device,
              ) for _ in range(num_transformer_blocks)
          ])

          self.fc = torch.nn.Linear(
              hidden_channels * len(self.aggrs),
              out_channels,
+             device=device,
          )

      def reset_parameters(self) -> None:
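
The hunks above thread a new device argument through PatchTransformerAggregation so that its submodules are allocated on the target device at construction time instead of being created on CPU and moved afterwards. A small sketch of the intended use; the argument names not visible in the hunks (in_channels, hidden_channels, out_channels, patch_size) are taken from the hunk bodies and are assumptions about the full signature:

    import torch
    from torch_geometric.nn.aggr import PatchTransformerAggregation

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Assumed keyword names; only heads/dropout/aggr/device appear verbatim in
    # the diff.  Parameters are created directly on `device`, so no separate
    # `.to(device)` pass over freshly initialized CPU tensors is needed.
    aggr = PatchTransformerAggregation(
        in_channels=16,
        hidden_channels=32,
        out_channels=32,
        patch_size=8,
        device=device,
    )
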
@@ -26,9 +26,11 @@ class MultiheadAttentionBlock(torch.nn.Module):
              normalization. (default: :obj:`True`)
          dropout (float, optional): Dropout probability of attention weights.
              (default: :obj:`0`)
+         device (torch.device, optional): The device of the module.
+             (default: :obj:`None`)
      """
      def __init__(self, channels: int, heads: int = 1, layer_norm: bool = True,
-                  dropout: float = 0.0):
+                  dropout: float = 0.0, device: Optional[torch.device] = None):
          super().__init__()

          self.channels = channels
@@ -40,10 +42,13 @@ class MultiheadAttentionBlock(torch.nn.Module):
              heads,
              batch_first=True,
              dropout=dropout,
+             device=device,
          )
-         self.lin = Linear(channels, channels)
-         self.layer_norm1 = LayerNorm(channels) if layer_norm else None
-         self.layer_norm2 = LayerNorm(channels) if layer_norm else None
+         self.lin = Linear(channels, channels, device=device)
+         self.layer_norm1 = LayerNorm(channels,
+                                      device=device) if layer_norm else None
+         self.layer_norm2 = LayerNorm(channels,
+                                      device=device) if layer_norm else None

      def reset_parameters(self):
          self.attn._reset_parameters()
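
MultiheadAttentionBlock receives the same treatment, and here the whole constructor is visible in the hunk. A short sketch, assuming the block keeps its torch_geometric.nn.aggr.utils import path:

    import torch
    from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock

    block = MultiheadAttentionBlock(channels=64, heads=4, layer_norm=True,
                                    dropout=0.1, device=torch.device('cpu'))

    # Attention, linear, and layer-norm submodules all live on the requested
    # device right after construction:
    assert all(p.device.type == 'cpu' for p in block.parameters())
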
@@ -1,7 +1,7 @@
  # The below is to suppress the warning on torch.nn.conv.MeshCNNConv::update
  # pyright: reportIncompatibleMethodOverride=false
+ import warnings
  from typing import Optional
- from warnings import warn

  import torch
  from torch.nn import Linear, Module, ModuleList
@@ -456,13 +456,10 @@ class MeshCNNConv(MessagePassing):
              {type(network)}"
          if not hasattr(network, "in_channels") and \
                  not hasattr(network, "in_features"):
-             warn(
-                 f"kernel[{i}] does not have attribute \
-                 'in_channels' nor 'out_features'. The \
-                 network must take as input a \
-                 {self.in_channels}-dimensional tensor. \
-                 Still, assuming user configured \
-                 correctly. Continuing..", stacklevel=2)
+             warnings.warn(
+                 f"kernel[{i}] does not have attribute 'in_channels' nor "
+                 f"'out_features'. The network must take as input a "
+                 f"{self.in_channels}-dimensional tensor.", stacklevel=2)
          else:
              input_dimension = getattr(network, "in_channels",
                                        network.in_features)
@@ -475,13 +472,10 @@ class MeshCNNConv(MessagePassing):

          if not hasattr(network, "out_channels") and \
                  not hasattr(network, "out_features"):
-             warn(
-                 f"kernel[{i}] does not have attribute \
-                 'in_channels' nor 'out_features'. The \
-                 network must take as input a \
-                 {self.in_channels}-dimensional tensor. \
-                 Still, assuming user configured \
-                 correctly. Continuing..", stacklevel=2)
+             warnings.warn(
+                 f"kernel[{i}] does not have attribute 'in_channels' nor "
+                 f"'out_features'. The network must take as input a "
+                 f"{self.in_channels}-dimensional tensor.", stacklevel=2)
          else:
              output_dimension = getattr(network, "out_channels",
                                         network.out_features)
@@ -1,4 +1,5 @@
  import math
+ from typing import Optional

  import torch
  from torch import Tensor
@@ -23,12 +24,15 @@ class PositionalEncoding(torch.nn.Module):
          granularity (float, optional): The granularity of the positions. If
              set to smaller value, the encoder will capture more fine-grained
              changes in positions. (default: :obj:`1.0`)
+         device (torch.device, optional): The device of the module.
+             (default: :obj:`None`)
      """
      def __init__(
          self,
          out_channels: int,
          base_freq: float = 1e-4,
          granularity: float = 1.0,
+         device: Optional[torch.device] = None,
      ):
          super().__init__()

@@ -40,7 +44,8 @@ class PositionalEncoding(torch.nn.Module):
          self.base_freq = base_freq
          self.granularity = granularity

-         frequency = torch.logspace(0, 1, out_channels // 2, base_freq)
+         frequency = torch.logspace(0, 1, out_channels // 2, base_freq,
+                                    device=device)
          self.register_buffer('frequency', frequency)

          self.reset_parameters()
@@ -75,13 +80,17 @@ class TemporalEncoding(torch.nn.Module):

      Args:
          out_channels (int): Size :math:`d` of each output sample.
+         device (torch.device, optional): The device of the module.
+             (default: :obj:`None`)
      """
-     def __init__(self, out_channels: int):
+     def __init__(self, out_channels: int,
+                  device: Optional[torch.device] = None):
          super().__init__()
          self.out_channels = out_channels

          sqrt = math.sqrt(out_channels)
-         weight = 1.0 / sqrt**torch.linspace(0, sqrt, out_channels).view(1, -1)
+         weight = 1.0 / sqrt**torch.linspace(0, sqrt, out_channels,
+                                             device=device).view(1, -1)
          self.register_buffer('weight', weight)

          self.reset_parameters()
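
PositionalEncoding and TemporalEncoding follow the same pattern: their frequency and weight buffers are now built on the requested device instead of being created on CPU first. A small sketch, assuming both classes remain importable from torch_geometric.nn:

    import torch
    from torch_geometric.nn import PositionalEncoding, TemporalEncoding

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    pos_enc = PositionalEncoding(out_channels=16, device=device)
    time_enc = TemporalEncoding(out_channels=16, device=device)

    t = torch.arange(10, dtype=torch.float, device=device)
    out = pos_enc(t)    # shape [10, 16], computed without a device transfer
    out = time_enc(t)   # shape [10, 16]
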
@@ -32,6 +32,8 @@ from .visnet import ViSNet
  from .g_retriever import GRetriever
  from .git_mol import GITMol
  from .molecule_gpt import MoleculeGPT
+ from .protein_mpnn import ProteinMPNN
  from .glem import GLEM
  from .sgformer import SGFormer
  # Deprecated:
@@ -86,6 +87,7 @@ __all__ = classes = [
      'GRetriever',
      'GITMol',
      'MoleculeGPT',
+     'ProteinMPNN',
      'GLEM',
      'SGFormer',
      'ARLinkPredictor',
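
The two hunks above register the new ProteinMPNN model in what appears to be torch_geometric/nn/models/__init__.py, making it part of the public models namespace:

    from torch_geometric.nn.models import ProteinMPNN
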
@@ -8,6 +8,13 @@ from torch_geometric.loader import DataLoader, NeighborLoader
  from torch_geometric.nn.models import GraphSAGE, basic_gnn


+ def deal_nan(x):
+     if isinstance(x, torch.Tensor):
+         x = x.clone()
+         x[torch.isnan(x)] = 0.0
+     return x
+
+
  class GLEM(torch.nn.Module):
      r"""This GNN+LM co-training model is based on GLEM from the `"Learning on
      Large-scale Text-attributed Graphs via Variational Inference"
@@ -379,9 +386,6 @@ class GLEM(torch.nn.Module):
          is_augmented: use EM or just train GNN and LM with gold data

          """
-         def deal_nan(x):
-             return 0 if torch.isnan(x) else x
-
          if is_augmented and (sum(~is_gold) > 0):
              mle_loss = deal_nan(loss_func(logits[is_gold], labels[is_gold]))
              # all other labels beside from ground truth(gold labels)
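
The GLEM change moves deal_nan out of the training helper and generalizes it: instead of returning 0 only when a scalar loss is NaN, the module-level version clones the tensor and zeroes NaN entries element-wise, leaving non-tensor inputs untouched. A self-contained sketch of the new behaviour, with the helper copied verbatim from the hunk:

    import torch

    def deal_nan(x):
        # Replace NaN entries with 0.0 on a copy; pass non-tensors through.
        if isinstance(x, torch.Tensor):
            x = x.clone()
            x[torch.isnan(x)] = 0.0
        return x

    print(deal_nan(torch.tensor(float('nan'))))              # tensor(0.)
    print(deal_nan(torch.tensor([1.0, float('nan'), 3.0])))  # tensor([1., 0., 3.])
    print(deal_nan(0.5))                                      # 0.5
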