pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic. Click here for more details.
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +77 -53
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +226 -189
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +14 -2
- torch_geometric/_compile.py +9 -3
- torch_geometric/_onnx.py +214 -0
- torch_geometric/config_mixin.py +5 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +109 -5
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +14 -11
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +18 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +15 -17
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +310 -209
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/partition.py +2 -2
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +18 -14
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +87 -3
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +14 -5
- torch_geometric/inspector.py +4 -0
- torch_geometric/io/fs.py +5 -4
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/{nn/nlp → llm/models}/llm.py +179 -31
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +4 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +14 -0
- torch_geometric/metrics/link_pred.py +745 -92
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +5 -4
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +0 -20
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +14 -12
- torch_geometric/nn/model_hub.py +2 -15
- torch_geometric/nn/models/__init__.py +11 -2
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/cluster_pool.py +3 -4
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +336 -7
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +296 -23
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -8
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/typing.py +70 -17
- torch_geometric/utils/__init__.py +4 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +27 -12
- torch_geometric/utils/_scatter.py +132 -195
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +13 -9
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +7 -14
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/sentence_transformer.py +0 -101
|
torch_geometric/datasets/hgb_dataset.py
CHANGED
|
@@ -123,8 +123,8 @@ class HGBDataset(InMemoryDataset):
|
|
|
123
123
|
start = info.index('LINK\tSTART\tEND\tMEANING') + 1
|
|
124
124
|
end = info[start:].index('')
|
|
125
125
|
for key, row in enumerate(info[start:start + end]):
|
|
126
|
-
|
|
127
|
-
src, dst, rel = (v for v in
|
|
126
|
+
edge = row.split('\t')[1:]
|
|
127
|
+
src, dst, rel = (v for v in edge if v != '')
|
|
128
128
|
src, dst = n_types[int(src)], n_types[int(dst)]
|
|
129
129
|
rel = rel.split('-')[1]
|
|
130
130
|
e_types[key] = (src, rel, dst)
|
torch_geometric/datasets/hm.py
CHANGED
|
@@ -81,7 +81,7 @@ class HM(InMemoryDataset):
|
|
|
81
81
|
xs.append(torch.from_numpy(x).to(torch.float))
|
|
82
82
|
|
|
83
83
|
x = torch.from_numpy(df['age'].values).to(torch.float).view(-1, 1)
|
|
84
|
-
x = x.nan_to_num(nan=x.nanmean())
|
|
84
|
+
x = x.nan_to_num(nan=x.nanmean()) # type: ignore
|
|
85
85
|
xs.append(x / x.max())
|
|
86
86
|
|
|
87
87
|
data['customer'].x = torch.cat(xs, dim=-1)
|
|
torch_geometric/datasets/instruct_mol_dataset.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import sys
|
|
3
|
+
from typing import Callable, List, Optional
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
|
|
8
|
+
from torch_geometric.data import Data, InMemoryDataset
|
|
9
|
+
from torch_geometric.io import fs
|
|
10
|
+
from torch_geometric.utils import one_hot
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class InstructMolDataset(InMemoryDataset):
|
|
14
|
+
r"""The dataset from the `"InstructMol: Multi-Modal Integration for
|
|
15
|
+
Building a Versatile and Reliable Molecular Assistant in Drug Discovery"
|
|
16
|
+
<https://arxiv.org/pdf/2311.16208>`_ paper.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
root (str): Root directory where the dataset should be saved.
|
|
20
|
+
transform (callable, optional): A function/transform that takes in an
|
|
21
|
+
:obj:`torch_geometric.data.Data` object and returns a transformed
|
|
22
|
+
version. The data object will be transformed before every access.
|
|
23
|
+
(default: :obj:`None`)
|
|
24
|
+
pre_transform (callable, optional): A function/transform that takes in
|
|
25
|
+
an :obj:`torch_geometric.data.Data` object and returns a
|
|
26
|
+
transformed version. The data object will be transformed before
|
|
27
|
+
being saved to disk. (default: :obj:`None`)
|
|
28
|
+
pre_filter (callable, optional): A function that takes in an
|
|
29
|
+
:obj:`torch_geometric.data.Data` object and returns a boolean
|
|
30
|
+
value, indicating whether the data object should be included in the
|
|
31
|
+
final dataset. (default: :obj:`None`)
|
|
32
|
+
force_reload (bool, optional): Whether to re-process the dataset.
|
|
33
|
+
(default: :obj:`False`)
|
|
34
|
+
"""
|
|
35
|
+
raw_url = 'https://huggingface.co/datasets/OpenMol/PubChemSFT/resolve/main'
|
|
36
|
+
|
|
37
|
+
def __init__(
|
|
38
|
+
self,
|
|
39
|
+
root: str,
|
|
40
|
+
transform: Optional[Callable] = None,
|
|
41
|
+
pre_transform: Optional[Callable] = None,
|
|
42
|
+
pre_filter: Optional[Callable] = None,
|
|
43
|
+
force_reload: bool = False,
|
|
44
|
+
):
|
|
45
|
+
super().__init__(root, transform, pre_transform, pre_filter,
|
|
46
|
+
force_reload=force_reload)
|
|
47
|
+
self.load(self.processed_paths[0])
|
|
48
|
+
|
|
49
|
+
@property
|
|
50
|
+
def raw_file_names(self) -> List[str]:
|
|
51
|
+
return ['all_clean.json']
|
|
52
|
+
|
|
53
|
+
@property
|
|
54
|
+
def processed_file_names(self) -> List[str]:
|
|
55
|
+
return ['data.pt']
|
|
56
|
+
|
|
57
|
+
def download(self) -> None:
|
|
58
|
+
print('downloading dataset...')
|
|
59
|
+
fs.cp(f'{self.raw_url}/all_clean.json', self.raw_dir)
|
|
60
|
+
|
|
61
|
+
def process(self) -> None:
|
|
62
|
+
try:
|
|
63
|
+
from rdkit import Chem
|
|
64
|
+
from rdkit.Chem.rdchem import BondType as BT
|
|
65
|
+
WITH_RDKIT = True
|
|
66
|
+
|
|
67
|
+
except ImportError:
|
|
68
|
+
WITH_RDKIT = False
|
|
69
|
+
|
|
70
|
+
if not WITH_RDKIT:
|
|
71
|
+
print(("Using a pre-processed version of the dataset. Please "
|
|
72
|
+
"install 'rdkit' to alternatively process the raw data."),
|
|
73
|
+
file=sys.stderr)
|
|
74
|
+
|
|
75
|
+
data_list = fs.torch_load(self.raw_paths[0])
|
|
76
|
+
data_list = [Data(**data_dict) for data_dict in data_list]
|
|
77
|
+
|
|
78
|
+
if self.pre_filter is not None:
|
|
79
|
+
data_list = [d for d in data_list if self.pre_filter(d)]
|
|
80
|
+
|
|
81
|
+
if self.pre_transform is not None:
|
|
82
|
+
data_list = [self.pre_transform(d) for d in data_list]
|
|
83
|
+
|
|
84
|
+
self.save(data_list, self.processed_paths[0])
|
|
85
|
+
return
|
|
86
|
+
|
|
87
|
+
# types of atom and bond
|
|
88
|
+
types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5}
|
|
89
|
+
bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}
|
|
90
|
+
|
|
91
|
+
# load data
|
|
92
|
+
mols = json.load(open(f'{self.raw_dir}/all_clean.json'))
|
|
93
|
+
|
|
94
|
+
data_list = []
|
|
95
|
+
for smiles, qa_pairs in tqdm(mols.items(), total=len(mols)):
|
|
96
|
+
mol = Chem.MolFromSmiles(smiles)
|
|
97
|
+
if mol is None:
|
|
98
|
+
continue
|
|
99
|
+
|
|
100
|
+
x: torch.Tensor = torch.tensor([
|
|
101
|
+
types[atom.GetSymbol()] if atom.GetSymbol() in types else 5
|
|
102
|
+
for atom in mol.GetAtoms()
|
|
103
|
+
])
|
|
104
|
+
x = one_hot(x, num_classes=len(types), dtype=torch.float)
|
|
105
|
+
|
|
106
|
+
rows, cols, edge_types = [], [], []
|
|
107
|
+
for bond in mol.GetBonds():
|
|
108
|
+
i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
|
|
109
|
+
edge_types += [bonds[bond.GetBondType()]] * 2
|
|
110
|
+
rows += [i, j]
|
|
111
|
+
cols += [j, i]
|
|
112
|
+
|
|
113
|
+
edge_index = torch.tensor([rows, cols], dtype=torch.long)
|
|
114
|
+
edge_type = torch.tensor(edge_types, dtype=torch.long)
|
|
115
|
+
edge_attr = one_hot(edge_type, num_classes=len(bonds))
|
|
116
|
+
|
|
117
|
+
for question, answer in qa_pairs:
|
|
118
|
+
data = Data(
|
|
119
|
+
x=x,
|
|
120
|
+
edge_index=edge_index,
|
|
121
|
+
edge_attr=edge_attr,
|
|
122
|
+
smiles=smiles,
|
|
123
|
+
instruction=question,
|
|
124
|
+
y=answer,
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
if self.pre_filter is not None and not self.pre_filter(data):
|
|
128
|
+
continue
|
|
129
|
+
if self.pre_transform is not None:
|
|
130
|
+
data = self.pre_transform(data)
|
|
131
|
+
|
|
132
|
+
data_list.append(data)
|
|
133
|
+
|
|
134
|
+
self.save(data_list, self.processed_paths[0])
|
torch_geometric/datasets/md17.py
CHANGED
|
@@ -57,7 +57,7 @@ class MD17(InMemoryDataset):
|
|
|
57
57
|
+--------------------+--------------------+-------------------------------+-----------+
|
|
58
58
|
| Uracil | DFT | :obj:`uracil` | 133,770 |
|
|
59
59
|
+--------------------+--------------------+-------------------------------+-----------+
|
|
60
|
-
| Naphthalene | DFT | :obj:`
|
|
60
|
+
| Naphthalene | DFT | :obj:`naphthalene` | 326,250 |
|
|
61
61
|
+--------------------+--------------------+-------------------------------+-----------+
|
|
62
62
|
| Aspirin | DFT | :obj:`aspirin` | 211,762 |
|
|
63
63
|
+--------------------+--------------------+-------------------------------+-----------+
|
|
@@ -77,7 +77,7 @@ class MD17(InMemoryDataset):
|
|
|
77
77
|
+--------------------+--------------------+-------------------------------+-----------+
|
|
78
78
|
| Uracil (R) | DFT (PBE/def2-SVP) | :obj:`revised uracil` | 100,000 |
|
|
79
79
|
+--------------------+--------------------+-------------------------------+-----------+
|
|
80
|
-
| Naphthalene (R) | DFT (PBE/def2-SVP) | :obj:`revised
|
|
80
|
+
| Naphthalene (R) | DFT (PBE/def2-SVP) | :obj:`revised naphthalene` | 100,000 |
|
|
81
81
|
+--------------------+--------------------+-------------------------------+-----------+
|
|
82
82
|
| Aspirin (R) | DFT (PBE/def2-SVP) | :obj:`revised aspirin` | 100,000 |
|
|
83
83
|
+--------------------+--------------------+-------------------------------+-----------+
|
|
@@ -309,7 +309,7 @@ class MD17(InMemoryDataset):
|
|
|
309
309
|
file_names = {
|
|
310
310
|
'benzene': 'md17_benzene2017.npz',
|
|
311
311
|
'uracil': 'md17_uracil.npz',
|
|
312
|
-
'
|
|
312
|
+
'naphthalene': 'md17_naphthalene.npz',
|
|
313
313
|
'aspirin': 'md17_aspirin.npz',
|
|
314
314
|
'salicylic acid': 'md17_salicylic.npz',
|
|
315
315
|
'malonaldehyde': 'md17_malonaldehyde.npz',
|
|
torch_geometric/datasets/medshapenet.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import os.path as osp
|
|
3
|
+
from typing import Callable, List, Optional
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from torch_geometric.data import Data, InMemoryDataset
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MedShapeNet(InMemoryDataset):
|
|
12
|
+
r"""The MedShapeNet datasets from the `"MedShapeNet -- A Large-Scale
|
|
13
|
+
Dataset of 3D Medical Shapes for Computer Vision"
|
|
14
|
+
<https://arxiv.org/abs/2308.16139>`_ paper,
|
|
15
|
+
containing 8 different types of structures (classes).
|
|
16
|
+
|
|
17
|
+
.. note::
|
|
18
|
+
|
|
19
|
+
Data objects hold mesh faces instead of edge indices.
|
|
20
|
+
To convert the mesh to a graph, use the
|
|
21
|
+
:obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.
|
|
22
|
+
To convert the mesh to a point cloud, use the
|
|
23
|
+
:obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to
|
|
24
|
+
sample a fixed number of points on the mesh faces according to their
|
|
25
|
+
face area.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
root (str): Root directory where the dataset should be saved.
|
|
29
|
+
size (int): Number of individual 3D structures to download per
|
|
30
|
+
type (classes).
|
|
31
|
+
transform (callable, optional): A function/transform that takes in an
|
|
32
|
+
:obj:`torch_geometric.data.Data` object and returns a transformed
|
|
33
|
+
version. The data object will be transformed before every access.
|
|
34
|
+
(default: :obj:`None`)
|
|
35
|
+
pre_transform (callable, optional): A function/transform that takes in
|
|
36
|
+
an :obj:`torch_geometric.data.Data` object and returns a
|
|
37
|
+
transformed version. The data object will be transformed before
|
|
38
|
+
being saved to disk. (default: :obj:`None`)
|
|
39
|
+
pre_filter (callable, optional): A function that takes in an
|
|
40
|
+
:obj:`torch_geometric.data.Data` object and returns a boolean
|
|
41
|
+
value, indicating whether the data object should be included in the
|
|
42
|
+
final dataset. (default: :obj:`None`)
|
|
43
|
+
force_reload (bool, optional): Whether to re-process the dataset.
|
|
44
|
+
(default: :obj:`False`)
|
|
45
|
+
"""
|
|
46
|
+
def __init__(
|
|
47
|
+
self,
|
|
48
|
+
root: str,
|
|
49
|
+
size: int = 100,
|
|
50
|
+
transform: Optional[Callable] = None,
|
|
51
|
+
pre_transform: Optional[Callable] = None,
|
|
52
|
+
pre_filter: Optional[Callable] = None,
|
|
53
|
+
force_reload: bool = False,
|
|
54
|
+
) -> None:
|
|
55
|
+
self.size = size
|
|
56
|
+
super().__init__(root, transform, pre_transform, pre_filter,
|
|
57
|
+
force_reload=force_reload)
|
|
58
|
+
|
|
59
|
+
path = self.processed_paths[0]
|
|
60
|
+
self.load(path)
|
|
61
|
+
|
|
62
|
+
@property
|
|
63
|
+
def raw_file_names(self) -> List[str]:
|
|
64
|
+
return [
|
|
65
|
+
'3DTeethSeg', 'CoronaryArteries', 'FLARE', 'KITS', 'PULMONARY',
|
|
66
|
+
'SurgicalInstruments', 'ThoracicAorta_Saitta', 'ToothFairy'
|
|
67
|
+
]
|
|
68
|
+
|
|
69
|
+
@property
|
|
70
|
+
def processed_file_names(self) -> List[str]:
|
|
71
|
+
return ['dataset.pt']
|
|
72
|
+
|
|
73
|
+
@property
|
|
74
|
+
def raw_paths(self) -> List[str]:
|
|
75
|
+
r"""The absolute filepaths that must be present in order to skip
|
|
76
|
+
downloading.
|
|
77
|
+
"""
|
|
78
|
+
return [osp.join(self.raw_dir, f) for f in self.raw_file_names]
|
|
79
|
+
|
|
80
|
+
def process(self) -> None:
|
|
81
|
+
import urllib3
|
|
82
|
+
from MedShapeNet import MedShapeNet as msn
|
|
83
|
+
|
|
84
|
+
msn_instance = msn(timeout=120)
|
|
85
|
+
|
|
86
|
+
urllib3.HTTPConnectionPool("medshapenet.ddns.net", maxsize=50)
|
|
87
|
+
|
|
88
|
+
list_of_datasets = msn_instance.datasets(False)
|
|
89
|
+
list_of_datasets = list(
|
|
90
|
+
filter(
|
|
91
|
+
lambda x: x not in [
|
|
92
|
+
'medshapenetcore/ASOCA', 'medshapenetcore/AVT',
|
|
93
|
+
'medshapenetcore/AutoImplantCraniotomy',
|
|
94
|
+
'medshapenetcore/FaceVR'
|
|
95
|
+
], list_of_datasets))
|
|
96
|
+
|
|
97
|
+
subset = []
|
|
98
|
+
for dataset in list_of_datasets:
|
|
99
|
+
parts = dataset.split("/")
|
|
100
|
+
self.newpath = self.root + '/' + parts[1 if len(parts) > 1 else 0]
|
|
101
|
+
if not os.path.exists(self.newpath):
|
|
102
|
+
os.makedirs(self.newpath)
|
|
103
|
+
stl_files = msn_instance.dataset_files(dataset, '.stl')
|
|
104
|
+
subset.extend(stl_files[:self.size])
|
|
105
|
+
|
|
106
|
+
for stl_file in stl_files[:self.size]:
|
|
107
|
+
msn_instance.download_stl_as_numpy(bucket_name=dataset,
|
|
108
|
+
stl_file=stl_file,
|
|
109
|
+
output_dir=self.newpath,
|
|
110
|
+
print_output=False)
|
|
111
|
+
|
|
112
|
+
class_mapping = {
|
|
113
|
+
'3DTeethSeg': 0,
|
|
114
|
+
'CoronaryArteries': 1,
|
|
115
|
+
'FLARE': 2,
|
|
116
|
+
'KITS': 3,
|
|
117
|
+
'PULMONARY': 4,
|
|
118
|
+
'SurgicalInstruments': 5,
|
|
119
|
+
'ThoracicAorta_Saitta': 6,
|
|
120
|
+
'ToothFairy': 7
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
for dataset, path in zip([subset], self.processed_paths):
|
|
124
|
+
data_list = []
|
|
125
|
+
for item in dataset:
|
|
126
|
+
class_name = item.split("/")[0]
|
|
127
|
+
item = item.split("stl")[0]
|
|
128
|
+
target = class_mapping[class_name]
|
|
129
|
+
file = osp.join(self.root, item + 'npz')
|
|
130
|
+
|
|
131
|
+
data = np.load(file)
|
|
132
|
+
pre_data_list = Data(
|
|
133
|
+
pos=torch.tensor(data["vertices"], dtype=torch.float),
|
|
134
|
+
face=torch.tensor(data["faces"],
|
|
135
|
+
dtype=torch.long).t().contiguous())
|
|
136
|
+
pre_data_list.y = torch.tensor([target], dtype=torch.long)
|
|
137
|
+
data_list.append(pre_data_list)
|
|
138
|
+
|
|
139
|
+
if self.pre_filter is not None:
|
|
140
|
+
data_list = [d for d in data_list if self.pre_filter(d)]
|
|
141
|
+
|
|
142
|
+
if self.pre_transform is not None:
|
|
143
|
+
data_list = [self.pre_transform(d) for d in data_list]
|
|
144
|
+
|
|
145
|
+
self.save(data_list, path)
|
|
torch_geometric/datasets/modelnet.py
CHANGED
|
@@ -79,7 +79,7 @@ class ModelNet(InMemoryDataset):
|
|
|
79
79
|
|
|
80
80
|
urls = {
|
|
81
81
|
'10':
|
|
82
|
-
'http://
|
|
82
|
+
'http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip', # noqa
|
|
83
83
|
'40': 'http://modelnet.cs.princeton.edu/ModelNet40.zip'
|
|
84
84
|
}
|
|
85
85
|
|