pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (229) hide show
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_to_dense_batch.py +2 -2
  215. torch_geometric/utils/_trim_to_layer.py +2 -2
  216. torch_geometric/utils/convert.py +17 -10
  217. torch_geometric/utils/cross_entropy.py +34 -13
  218. torch_geometric/utils/embedding.py +91 -2
  219. torch_geometric/utils/geodesic.py +4 -3
  220. torch_geometric/utils/influence.py +279 -0
  221. torch_geometric/utils/map.py +13 -9
  222. torch_geometric/utils/nested.py +1 -1
  223. torch_geometric/utils/smiles.py +3 -3
  224. torch_geometric/utils/sparse.py +7 -14
  225. torch_geometric/visualization/__init__.py +2 -1
  226. torch_geometric/visualization/graph.py +250 -5
  227. torch_geometric/warnings.py +11 -2
  228. torch_geometric/nn/nlp/__init__.py +0 -7
  229. torch_geometric/nn/nlp/sentence_transformer.py +0 -101
@@ -123,8 +123,8 @@ class HGBDataset(InMemoryDataset):
123
123
  start = info.index('LINK\tSTART\tEND\tMEANING') + 1
124
124
  end = info[start:].index('')
125
125
  for key, row in enumerate(info[start:start + end]):
126
- row = row.split('\t')[1:]
127
- src, dst, rel = (v for v in row if v != '')
126
+ edge = row.split('\t')[1:]
127
+ src, dst, rel = (v for v in edge if v != '')
128
128
  src, dst = n_types[int(src)], n_types[int(dst)]
129
129
  rel = rel.split('-')[1]
130
130
  e_types[key] = (src, rel, dst)
@@ -81,7 +81,7 @@ class HM(InMemoryDataset):
81
81
  xs.append(torch.from_numpy(x).to(torch.float))
82
82
 
83
83
  x = torch.from_numpy(df['age'].values).to(torch.float).view(-1, 1)
84
- x = x.nan_to_num(nan=x.nanmean())
84
+ x = x.nan_to_num(nan=x.nanmean()) # type: ignore
85
85
  xs.append(x / x.max())
86
86
 
87
87
  data['customer'].x = torch.cat(xs, dim=-1)
@@ -0,0 +1,134 @@
1
+ import json
2
+ import sys
3
+ from typing import Callable, List, Optional
4
+
5
+ import torch
6
+ from tqdm import tqdm
7
+
8
+ from torch_geometric.data import Data, InMemoryDataset
9
+ from torch_geometric.io import fs
10
+ from torch_geometric.utils import one_hot
11
+
12
+
13
class InstructMolDataset(InMemoryDataset):
    r"""The dataset from the `"InstructMol: Multi-Modal Integration for
    Building a Versatile and Reliable Molecular Assistant in Drug Discovery"
    <https://arxiv.org/pdf/2311.16208>`_ paper.

    Each SMILES string of the raw file is converted into a molecular graph
    with one-hot encoded atom types (:obj:`x`) and bond types
    (:obj:`edge_attr`). One :class:`~torch_geometric.data.Data` object is
    created per (question, answer) pair attached to that molecule. It holds
    the graph, the :obj:`smiles` string, the question as
    :obj:`instruction`, and the answer as :obj:`y`.

    .. note::

        Processing the raw data requires :obj:`rdkit`.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    raw_url = 'https://huggingface.co/datasets/OpenMol/PubChemSFT/resolve/main'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ):
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return ['all_clean.json']

    @property
    def processed_file_names(self) -> List[str]:
        return ['data.pt']

    def download(self) -> None:
        print('downloading dataset...')
        fs.cp(f'{self.raw_url}/all_clean.json', self.raw_dir)

    def process(self) -> None:
        r"""Converts the raw JSON file into a list of data objects.

        Raises:
            ImportError: If :obj:`rdkit` is not installed.
        """
        try:
            from rdkit import Chem
            from rdkit.Chem.rdchem import BondType as BT
        except ImportError as e:
            # There is no rdkit-free fallback. The raw file is JSON, not a
            # serialized list of data objects, and this class downloads no
            # pre-processed copy. Loading it via `torch.load` can only fail,
            # so fail early with an actionable message instead.
            raise ImportError(
                f"'{self.__class__.__name__}' requires 'rdkit' to process "
                f"the raw data. Please install it via 'pip install rdkit'."
            ) from e

        # Atom-type vocabulary; elements not listed map to 'Unknow'.
        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5}
        # Bond-type vocabulary; other RDKit bond types raise a `KeyError`.
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

        # The raw file maps each SMILES string to its list of
        # (question, answer) pairs. JSON is UTF-8 by specification, so do
        # not rely on the platform default encoding.
        with open(self.raw_paths[0], encoding='utf-8') as f:
            mols = json.load(f)

        data_list = []
        for smiles, qa_pairs in tqdm(mols.items(), total=len(mols)):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:  # Skip SMILES strings RDKit cannot parse.
                continue

            # An explicit integer dtype keeps atom-free molecules valid
            # input for `one_hot`: `torch.tensor([])` would be float.
            x = torch.tensor([
                types.get(atom.GetSymbol(), types['Unknow'])
                for atom in mol.GetAtoms()
            ], dtype=torch.long)
            x = one_hot(x, num_classes=len(types), dtype=torch.float)

            # Every bond is stored in both directions.
            rows, cols, edge_types = [], [], []
            for bond in mol.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                edge_types += [bonds[bond.GetBondType()]] * 2
                rows += [i, j]
                cols += [j, i]

            edge_index = torch.tensor([rows, cols], dtype=torch.long)
            edge_type = torch.tensor(edge_types, dtype=torch.long)
            edge_attr = one_hot(edge_type, num_classes=len(bonds))

            # One data object per instruction; all share the same graph.
            for question, answer in qa_pairs:
                data = Data(
                    x=x,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    smiles=smiles,
                    instruction=question,
                    y=answer,
                )

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

        self.save(data_list, self.processed_paths[0])
@@ -57,7 +57,7 @@ class MD17(InMemoryDataset):
57
57
  +--------------------+--------------------+-------------------------------+-----------+
58
58
  | Uracil | DFT | :obj:`uracil` | 133,770 |
59
59
  +--------------------+--------------------+-------------------------------+-----------+
60
- | Naphthalene | DFT | :obj:`napthalene` | 326,250 |
60
+ | Naphthalene | DFT | :obj:`naphthalene` | 326,250 |
61
61
  +--------------------+--------------------+-------------------------------+-----------+
62
62
  | Aspirin | DFT | :obj:`aspirin` | 211,762 |
63
63
  +--------------------+--------------------+-------------------------------+-----------+
@@ -77,7 +77,7 @@ class MD17(InMemoryDataset):
77
77
  +--------------------+--------------------+-------------------------------+-----------+
78
78
  | Uracil (R) | DFT (PBE/def2-SVP) | :obj:`revised uracil` | 100,000 |
79
79
  +--------------------+--------------------+-------------------------------+-----------+
80
- | Naphthalene (R) | DFT (PBE/def2-SVP) | :obj:`revised napthalene` | 100,000 |
80
+ | Naphthalene (R) | DFT (PBE/def2-SVP) | :obj:`revised naphthalene` | 100,000 |
81
81
  +--------------------+--------------------+-------------------------------+-----------+
82
82
  | Aspirin (R) | DFT (PBE/def2-SVP) | :obj:`revised aspirin` | 100,000 |
83
83
  +--------------------+--------------------+-------------------------------+-----------+
@@ -309,7 +309,7 @@ class MD17(InMemoryDataset):
309
309
  file_names = {
310
310
  'benzene': 'md17_benzene2017.npz',
311
311
  'uracil': 'md17_uracil.npz',
312
- 'naphtalene': 'md17_naphthalene.npz',
312
+ 'naphthalene': 'md17_naphthalene.npz',
313
313
  'aspirin': 'md17_aspirin.npz',
314
314
  'salicylic acid': 'md17_salicylic.npz',
315
315
  'malonaldehyde': 'md17_malonaldehyde.npz',
@@ -0,0 +1,145 @@
1
+ import os
2
+ import os.path as osp
3
+ from typing import Callable, List, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from torch_geometric.data import Data, InMemoryDataset
9
+
10
+
11
class MedShapeNet(InMemoryDataset):
    r"""The MedShapeNet datasets from the `"MedShapeNet -- A Large-Scale
    Dataset of 3D Medical Shapes for Computer Vision"
    <https://arxiv.org/abs/2308.16139>`_ paper,
    containing 8 different types of structures (classes).

    .. note::

        Data objects hold mesh faces instead of edge indices.
        To convert the mesh to a graph, use the
        :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.
        To convert the mesh to a point cloud, use the
        :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to
        sample a fixed number of points on the mesh faces according to their
        face area.

    .. note::

        Processing requires the :obj:`MedShapeNet` package. The meshes are
        downloaded through it while the dataset is processed.

    Args:
        root (str): Root directory where the dataset should be saved.
        size (int, optional): Number of individual 3D structures to download
            per type (class). (default: :obj:`100`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    def __init__(
        self,
        root: str,
        size: int = 100,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        # Must be set before `super().__init__`, which may call `process()`.
        self.size = size
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)

        path = self.processed_paths[0]
        self.load(path)

    @property
    def raw_file_names(self) -> List[str]:
        # The structure types used as classes; a type's position in this
        # list is its class label.
        return [
            '3DTeethSeg', 'CoronaryArteries', 'FLARE', 'KITS', 'PULMONARY',
            'SurgicalInstruments', 'ThoracicAorta_Saitta', 'ToothFairy'
        ]

    @property
    def processed_file_names(self) -> List[str]:
        return ['dataset.pt']

    @property
    def raw_paths(self) -> List[str]:
        r"""The absolute filepaths that must be present in order to skip
        downloading.
        """
        return [osp.join(self.raw_dir, f) for f in self.raw_file_names]

    def process(self) -> None:
        r"""Downloads up to :obj:`size` meshes per structure type and
        converts them into labeled :class:`~torch_geometric.data.Data`
        objects.
        """
        from MedShapeNet import MedShapeNet as msn

        msn_instance = msn(timeout=120)

        # Class label of each structure type, in `raw_file_names` order.
        class_mapping = {
            name: label
            for label, name in enumerate(self.raw_file_names)
        }

        # Only datasets with a known class label are fetched. Any other
        # dataset would be downloaded only to fail the label lookup below.
        items: List[str] = []
        for dataset in msn_instance.datasets(False):
            parts = dataset.split('/')
            name = parts[1 if len(parts) > 1 else 0]
            if name not in class_mapping:
                continue

            out_dir = osp.join(self.root, name)
            os.makedirs(out_dir, exist_ok=True)

            stl_files = msn_instance.dataset_files(dataset, '.stl')
            stl_files = stl_files[:self.size]
            items.extend(stl_files)

            for stl_file in stl_files:
                msn_instance.download_stl_as_numpy(bucket_name=dataset,
                                                   stl_file=stl_file,
                                                   output_dir=out_dir,
                                                   print_output=False)

        data_list = []
        for item in items:
            # Presumably items look like '<class_name>/<file>.stl', with the
            # converted mesh stored as '<root>/<class_name>/<file>.npz'.
            # TODO(review): confirm against the MedShapeNet package.
            class_name = item.split('/')[0]
            file = osp.join(self.root, osp.splitext(item)[0] + '.npz')

            # `NpzFile` keeps the archive open until closed.
            with np.load(file) as mesh:
                pos = torch.tensor(mesh['vertices'], dtype=torch.float)
                face = torch.tensor(mesh['faces'], dtype=torch.long)

            data = Data(pos=pos, face=face.t().contiguous())
            data.y = torch.tensor([class_mapping[class_name]],
                                  dtype=torch.long)
            data_list.append(data)

        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        self.save(data_list, self.processed_paths[0])
@@ -79,7 +79,7 @@ class ModelNet(InMemoryDataset):
79
79
 
80
80
  urls = {
81
81
  '10':
82
- 'http://vision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip',
82
+ 'http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip', # noqa
83
83
  '40': 'http://modelnet.cs.princeton.edu/ModelNet40.zip'
84
84
  }
85
85