pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic. Click here for more details.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
|
|
8
|
+
from torch_geometric.data import (
|
|
9
|
+
Data,
|
|
10
|
+
InMemoryDataset,
|
|
11
|
+
download_google_url,
|
|
12
|
+
extract_zip,
|
|
13
|
+
)
|
|
14
|
+
from torch_geometric.io import fs
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def safe_index(lst: List[Any], e: Any) -> int:
    r"""Returns the index of :obj:`e` in :obj:`lst`, or the index of the last
    entry of :obj:`lst` if :obj:`e` is not contained in it.

    The vocabularies this helper is called with end in a catch-all entry,
    so unknown values are mapped onto that last slot.

    Args:
        lst (List[Any]): The list of allowed values.
        e (Any): The value to look up.

    Returns:
        int: The position of the first match, or :obj:`len(lst) - 1` when
        there is no match (which is :obj:`-1` for an empty list).
    """
    # A single scan via EAFP instead of `e in lst` followed by `lst.index(e)`.
    try:
        return lst.index(e)
    except ValueError:
        return len(lst) - 1
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class GitMolDataset(InMemoryDataset):
    r"""The dataset from the `"GIT-Mol: A Multi-modal Large Language Model
    for Molecular Science with Graph, Image, and Text"
    <https://arxiv.org/pdf/2308.06911>`_ paper.

    Every example holds a molecular graph (:obj:`x`, :obj:`edge_index`,
    :obj:`edge_attr`), its SMILES string (:obj:`smiles`), a pre-processed
    image tensor (:obj:`image`) and a text description (:obj:`caption`).

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
        split (int, optional): Dataset split, train/valid/test=0/1/2.
            (default: :obj:`0`)
    """

    raw_url_id = '1loBXabD6ncAFY-vanRsVtRUSFkEtBweg'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
        split: int = 0,
    ):
        from torchvision import transforms

        # `processed_file_names` reads `self.split`, so it is assigned before
        # the base-class constructor runs.
        self.split = split

        # The training split (0) additionally gets rotation/color jitter.
        # The transform is applied inside `process()`, i.e., its output is
        # what gets written to the processed file.
        steps: List[Any] = [transforms.Resize((224, 224))]
        if self.split == 0:
            steps += [
                transforms.RandomRotation(15),
                transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
            ]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]
        self.img_transform = transforms.Compose(steps)

        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)

        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        # NOTE(review): `process()` reads these files from the
        # `igcdata_toy/` sub-folder of `raw_dir`, not from `raw_dir` itself.
        # Confirm that the raw-file existence check behaves as intended.
        return ['train_3500.pkl', 'valid_450.pkl', 'test_450.pkl']

    @property
    def processed_file_names(self) -> str:
        # One processed file per split, selected via `self.split`.
        return ['train.pt', 'valid.pt', 'test.pt'][self.split]

    def download(self) -> None:
        """Downloads the archive into :obj:`raw_dir` and extracts it there."""
        file_path = download_google_url(
            self.raw_url_id,
            self.raw_dir,
            'gitmol.zip',
        )
        extract_zip(file_path, self.raw_dir)

    def process(self) -> None:
        """Converts the raw records of the selected split into
        :class:`~torch_geometric.data.Data` objects and saves them to
        :obj:`processed_paths[0]`.
        """
        import pandas as pd
        from PIL import Image

        try:
            from rdkit import Chem, RDLogger
            RDLogger.DisableLog('rdApp.*')  # type: ignore[attr-defined]
            WITH_RDKIT = True

        except ImportError:
            WITH_RDKIT = False

        if not WITH_RDKIT:
            print(("Using a pre-processed version of the dataset. Please "
                   "install 'rdkit' to alternatively process the raw data."),
                  file=sys.stderr)

            # NOTE(review): this always reads `raw_paths[0]`, independent of
            # `self.split` - confirm that this fallback is intended.
            data_list = fs.torch_load(self.raw_paths[0])
            data_list = [Data(**data_dict) for data_dict in data_list]

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            self.save(data_list, self.processed_paths[0])
            return

        # Categorical vocabularies. Features are stored as the index of the
        # value inside its vocabulary.
        allowable_features: Dict[str, List[Any]] = {
            'possible_atomic_num_list':
            list(range(1, 119)) + ['misc'],
            'possible_formal_charge_list':
            [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'],
            'possible_chirality_list': [
                Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
                Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
                Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
                Chem.rdchem.ChiralType.CHI_OTHER
            ],
            'possible_hybridization_list': [
                Chem.rdchem.HybridizationType.SP,
                Chem.rdchem.HybridizationType.SP2,
                Chem.rdchem.HybridizationType.SP3,
                Chem.rdchem.HybridizationType.SP3D,
                Chem.rdchem.HybridizationType.SP3D2,
                Chem.rdchem.HybridizationType.UNSPECIFIED, 'misc'
            ],
            'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
            'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6],
            'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
            'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'],
            'possible_is_aromatic_list': [False, True],
            'possible_is_in_ring_list': [False, True],
            'possible_bond_type_list': [
                Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC,
                Chem.rdchem.BondType.ZERO
            ],
            'possible_bond_dirs': [  # only for double bond stereo information
                Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT,
                Chem.rdchem.BondDir.ENDDOWNRIGHT
            ],
            'possible_bond_stereo_list': [
                Chem.rdchem.BondStereo.STEREONONE,
                Chem.rdchem.BondStereo.STEREOZ,
                Chem.rdchem.BondStereo.STEREOE,
                Chem.rdchem.BondStereo.STEREOCIS,
                Chem.rdchem.BondStereo.STEREOTRANS,
                Chem.rdchem.BondStereo.STEREOANY,
            ],
            'possible_is_conjugated_list': [False, True]
        }

        # NOTE: `pd.read_pickle` unpickles the downloaded file, which can
        # execute arbitrary code - only use it with a trusted source.
        # The frame gets its own name (`df`) so that it is not shadowed by
        # the per-row `Data` object created in the loop below.
        df = pd.read_pickle(
            f'{self.raw_dir}/igcdata_toy/{self.raw_file_names[self.split]}')

        data_list = []
        for _, r in tqdm(df.iterrows(), total=df.shape[0]):
            smiles = r['isosmiles']
            mol = Chem.MolFromSmiles(smiles.strip('\n'))
            if mol is None:  # Skip SMILES strings RDKit cannot parse.
                continue

            # text
            summary = r['summary']

            # image
            cid = r['cid']
            img_file = f'{self.raw_dir}/igcdata_toy/imgs/CID_{cid}.png'
            # Close the file handle as soon as the pixels are converted.
            with Image.open(img_file) as raw_img:
                img = raw_img.convert('RGB')
            img = self.img_transform(img).unsqueeze(0)

            # graph
            atom_features_list = []
            for atom in mol.GetAtoms():
                atom_feature = [
                    safe_index(
                        allowable_features['possible_atomic_num_list'],
                        atom.GetAtomicNum()),
                    allowable_features['possible_chirality_list'].index(
                        atom.GetChiralTag()),
                    safe_index(allowable_features['possible_degree_list'],
                               atom.GetTotalDegree()),
                    safe_index(
                        allowable_features['possible_formal_charge_list'],
                        atom.GetFormalCharge()),
                    safe_index(allowable_features['possible_numH_list'],
                               atom.GetTotalNumHs()),
                    safe_index(
                        allowable_features['possible_number_radical_e_list'],
                        atom.GetNumRadicalElectrons()),
                    safe_index(
                        allowable_features['possible_hybridization_list'],
                        atom.GetHybridization()),
                    allowable_features['possible_is_aromatic_list'].index(
                        atom.GetIsAromatic()),
                    allowable_features['possible_is_in_ring_list'].index(
                        atom.GetIsInRing()),
                ]
                atom_features_list.append(atom_feature)
            # `view` pins the feature dimension so that an empty list still
            # yields a 2D tensor with 9 columns (one per atom feature).
            x = torch.tensor(atom_features_list, dtype=torch.long).view(-1, 9)

            edges_list = []
            edge_features_list = []
            for bond in mol.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                edge_feature = [
                    safe_index(
                        allowable_features['possible_bond_type_list'],
                        bond.GetBondType()),
                    allowable_features['possible_bond_stereo_list'].index(
                        bond.GetStereo()),
                    allowable_features['possible_is_conjugated_list'].index(
                        bond.GetIsConjugated()),
                ]
                # Every bond is stored in both directions.
                edges_list.append((i, j))
                edge_features_list.append(edge_feature)
                edges_list.append((j, i))
                edge_features_list.append(edge_feature)

            # Molecules without bonds (e.g., single atoms) must still yield
            # `edge_index` of shape `[2, 0]` and `edge_attr` of shape
            # `[0, 3]` rather than 1D empty tensors.
            edge_index = torch.tensor(edges_list, dtype=torch.long)
            edge_index = edge_index.view(-1, 2).t().contiguous()
            edge_attr = torch.tensor(edge_features_list, dtype=torch.long)
            edge_attr = edge_attr.view(-1, 3)

            data = Data(
                x=x,
                edge_index=edge_index,
                smiles=smiles,
                edge_attr=edge_attr,
                image=img,
                caption=summary,
            )

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        self.save(data_list, self.processed_paths[0])
|
|
@@ -12,6 +12,7 @@ from torch_geometric.data import (
|
|
|
12
12
|
download_url,
|
|
13
13
|
extract_zip,
|
|
14
14
|
)
|
|
15
|
+
from torch_geometric.io import fs
|
|
15
16
|
from torch_geometric.utils import remove_self_loops
|
|
16
17
|
|
|
17
18
|
|
|
@@ -181,7 +182,7 @@ class GNNBenchmarkDataset(InMemoryDataset):
|
|
|
181
182
|
data_list = self.process_CSL()
|
|
182
183
|
self.save(data_list, self.processed_paths[0])
|
|
183
184
|
else:
|
|
184
|
-
inputs =
|
|
185
|
+
inputs = fs.torch_load(self.raw_paths[0])
|
|
185
186
|
for i in range(len(inputs)):
|
|
186
187
|
data_list = [Data(**data_dict) for data_dict in inputs[i]]
|
|
187
188
|
|
|
@@ -197,7 +198,7 @@ class GNNBenchmarkDataset(InMemoryDataset):
|
|
|
197
198
|
with open(self.raw_paths[0], 'rb') as f:
|
|
198
199
|
adjs = pickle.load(f)
|
|
199
200
|
|
|
200
|
-
ys =
|
|
201
|
+
ys = fs.torch_load(self.raw_paths[1]).tolist()
|
|
201
202
|
|
|
202
203
|
data_list = []
|
|
203
204
|
for adj, y in zip(adjs, ys):
|
|
@@ -123,8 +123,8 @@ class HGBDataset(InMemoryDataset):
|
|
|
123
123
|
start = info.index('LINK\tSTART\tEND\tMEANING') + 1
|
|
124
124
|
end = info[start:].index('')
|
|
125
125
|
for key, row in enumerate(info[start:start + end]):
|
|
126
|
-
|
|
127
|
-
src, dst, rel = (v for v in
|
|
126
|
+
edge = row.split('\t')[1:]
|
|
127
|
+
src, dst, rel = (v for v in edge if v != '')
|
|
128
128
|
src, dst = n_types[int(src)], n_types[int(dst)]
|
|
129
129
|
rel = rel.split('-')[1]
|
|
130
130
|
e_types[key] = (src, rel, dst)
|
torch_geometric/datasets/hm.py
CHANGED
|
@@ -81,7 +81,7 @@ class HM(InMemoryDataset):
|
|
|
81
81
|
xs.append(torch.from_numpy(x).to(torch.float))
|
|
82
82
|
|
|
83
83
|
x = torch.from_numpy(df['age'].values).to(torch.float).view(-1, 1)
|
|
84
|
-
x = x.nan_to_num(nan=x.nanmean())
|
|
84
|
+
x = x.nan_to_num(nan=x.nanmean()) # type: ignore
|
|
85
85
|
xs.append(x / x.max())
|
|
86
86
|
|
|
87
87
|
data['customer'].x = torch.cat(xs, dim=-1)
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import sys
|
|
3
|
+
from typing import Callable, List, Optional
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
|
|
8
|
+
from torch_geometric.data import Data, InMemoryDataset
|
|
9
|
+
from torch_geometric.io import fs
|
|
10
|
+
from torch_geometric.utils import one_hot
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class InstructMolDataset(InMemoryDataset):
    r"""The dataset from the `"InstructMol: Multi-Modal Integration for
    Building a Versatile and Reliable Molecular Assistant in Drug Discovery"
    <https://arxiv.org/pdf/2311.16208>`_ paper.

    Every example pairs a molecular graph (:obj:`x`, :obj:`edge_index`,
    :obj:`edge_attr`, :obj:`smiles`) with one question (:obj:`instruction`)
    and its answer (:obj:`y`). A molecule with several question/answer pairs
    yields one example per pair.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    raw_url = 'https://huggingface.co/datasets/OpenMol/PubChemSFT/resolve/main'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ):
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return ['all_clean.json']

    @property
    def processed_file_names(self) -> List[str]:
        return ['data.pt']

    def download(self) -> None:
        """Copies :obj:`all_clean.json` from :obj:`raw_url` into
        :obj:`raw_dir`.
        """
        print('downloading dataset...')
        fs.cp(f'{self.raw_url}/all_clean.json', self.raw_dir)

    def process(self) -> None:
        """Builds one :class:`~torch_geometric.data.Data` object per
        (molecule, question, answer) triple and saves the result to
        :obj:`processed_paths[0]`.
        """
        try:
            from rdkit import Chem
            from rdkit.Chem.rdchem import BondType as BT
            WITH_RDKIT = True

        except ImportError:
            WITH_RDKIT = False

        if not WITH_RDKIT:
            print(("Using a pre-processed version of the dataset. Please "
                   "install 'rdkit' to alternatively process the raw data."),
                  file=sys.stderr)

            # NOTE(review): `raw_paths[0]` refers to `all_clean.json`, which
            # the rdkit branch below parses as JSON - confirm that
            # `fs.torch_load` is the right loader for this fallback.
            data_list = fs.torch_load(self.raw_paths[0])
            data_list = [Data(**data_dict) for data_dict in data_list]

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            self.save(data_list, self.processed_paths[0])
            return

        # types of atom and bond
        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5}
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}
        unknown_type = types['Unknow']

        # load data (the context manager closes the file handle again)
        with open(f'{self.raw_dir}/all_clean.json') as f:
            mols = json.load(f)

        data_list = []
        for smiles, qa_pairs in tqdm(mols.items(), total=len(mols)):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:  # Skip SMILES strings RDKit cannot parse.
                continue

            # Element symbols outside of `types` share the 'Unknow' class.
            # The explicit dtype keeps the index tensor integral even when
            # the molecule has no atoms.
            x: torch.Tensor = torch.tensor(
                [
                    types.get(atom.GetSymbol(), unknown_type)
                    for atom in mol.GetAtoms()
                ],
                dtype=torch.long,
            )
            x = one_hot(x, num_classes=len(types), dtype=torch.float)

            rows, cols, edge_types = [], [], []
            for bond in mol.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                # NOTE(review): bond types that are not a key of `bonds`
                # raise a `KeyError` here - confirm this is acceptable for
                # the raw data.
                edge_types += [bonds[bond.GetBondType()]] * 2
                # Every bond is stored in both directions.
                rows += [i, j]
                cols += [j, i]

            edge_index = torch.tensor([rows, cols], dtype=torch.long)
            edge_type = torch.tensor(edge_types, dtype=torch.long)
            edge_attr = one_hot(edge_type, num_classes=len(bonds))

            # One example per question/answer pair. All examples of a
            # molecule reference the same graph tensors.
            for question, answer in qa_pairs:
                data = Data(
                    x=x,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    smiles=smiles,
                    instruction=question,
                    y=answer,
                )

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

        self.save(data_list, self.processed_paths[0])
|
|
@@ -5,6 +5,7 @@ import numpy as np
|
|
|
5
5
|
import torch
|
|
6
6
|
|
|
7
7
|
from torch_geometric.data import Data, InMemoryDataset, download_url
|
|
8
|
+
from torch_geometric.io import fs
|
|
8
9
|
from torch_geometric.utils import one_hot
|
|
9
10
|
|
|
10
11
|
|
|
@@ -115,9 +116,9 @@ class LINKXDataset(InMemoryDataset):
|
|
|
115
116
|
|
|
116
117
|
def _process_wiki(self) -> Data:
|
|
117
118
|
paths = {x.split('/')[-1]: x for x in self.raw_paths}
|
|
118
|
-
x =
|
|
119
|
-
edge_index =
|
|
120
|
-
y =
|
|
119
|
+
x = fs.torch_load(paths['wiki_features2M.pt'])
|
|
120
|
+
edge_index = fs.torch_load(paths['wiki_edges2M.pt']).t().contiguous()
|
|
121
|
+
y = fs.torch_load(paths['wiki_views2M.pt'])
|
|
121
122
|
|
|
122
123
|
return Data(x=x, edge_index=edge_index, y=y)
|
|
123
124
|
|
torch_geometric/datasets/lrgb.py
CHANGED
|
@@ -188,9 +188,8 @@ class LRGBDataset(InMemoryDataset):
|
|
|
188
188
|
graphs = pickle.load(f)
|
|
189
189
|
elif self.name.split('-')[0] == 'peptides':
|
|
190
190
|
# Peptides-func and Peptides-struct
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
graphs = torch.load(f)
|
|
191
|
+
graphs = fs.torch_load(
|
|
192
|
+
osp.join(self.raw_dir, f'{split}.pt'))
|
|
194
193
|
|
|
195
194
|
data_list = []
|
|
196
195
|
for graph in tqdm(graphs, desc=f'Processing {split} dataset'):
|
|
@@ -260,8 +259,7 @@ class LRGBDataset(InMemoryDataset):
|
|
|
260
259
|
|
|
261
260
|
def process_pcqm_contact(self) -> None:
|
|
262
261
|
for split in ['train', 'val', 'test']:
|
|
263
|
-
|
|
264
|
-
graphs = torch.load(f)
|
|
262
|
+
graphs = fs.torch_load(osp.join(self.raw_dir, f'{split}.pt'))
|
|
265
263
|
|
|
266
264
|
data_list = []
|
|
267
265
|
for graph in tqdm(graphs, desc=f'Processing {split} dataset'):
|
|
@@ -11,6 +11,7 @@ from torch_geometric.data import (
|
|
|
11
11
|
extract_tar,
|
|
12
12
|
extract_zip,
|
|
13
13
|
)
|
|
14
|
+
from torch_geometric.io import fs
|
|
14
15
|
|
|
15
16
|
|
|
16
17
|
class MalNetTiny(InMemoryDataset):
|
|
@@ -65,7 +66,7 @@ class MalNetTiny(InMemoryDataset):
|
|
|
65
66
|
self.load(self.processed_paths[0])
|
|
66
67
|
|
|
67
68
|
if split is not None:
|
|
68
|
-
split_slices =
|
|
69
|
+
split_slices = fs.torch_load(self.processed_paths[1])
|
|
69
70
|
if split == 'train':
|
|
70
71
|
self._indices = range(split_slices[0], split_slices[1])
|
|
71
72
|
elif split == 'val':
|
torch_geometric/datasets/md17.py
CHANGED
|
@@ -57,7 +57,7 @@ class MD17(InMemoryDataset):
|
|
|
57
57
|
+--------------------+--------------------+-------------------------------+-----------+
|
|
58
58
|
| Uracil | DFT | :obj:`uracil` | 133,770 |
|
|
59
59
|
+--------------------+--------------------+-------------------------------+-----------+
|
|
60
|
-
| Naphthalene | DFT | :obj:`
|
|
60
|
+
| Naphthalene | DFT | :obj:`naphthalene` | 326,250 |
|
|
61
61
|
+--------------------+--------------------+-------------------------------+-----------+
|
|
62
62
|
| Aspirin | DFT | :obj:`aspirin` | 211,762 |
|
|
63
63
|
+--------------------+--------------------+-------------------------------+-----------+
|
|
@@ -77,7 +77,7 @@ class MD17(InMemoryDataset):
|
|
|
77
77
|
+--------------------+--------------------+-------------------------------+-----------+
|
|
78
78
|
| Uracil (R) | DFT (PBE/def2-SVP) | :obj:`revised uracil` | 100,000 |
|
|
79
79
|
+--------------------+--------------------+-------------------------------+-----------+
|
|
80
|
-
| Naphthalene (R) | DFT (PBE/def2-SVP) | :obj:`revised
|
|
80
|
+
| Naphthalene (R) | DFT (PBE/def2-SVP) | :obj:`revised naphthalene` | 100,000 |
|
|
81
81
|
+--------------------+--------------------+-------------------------------+-----------+
|
|
82
82
|
| Aspirin (R) | DFT (PBE/def2-SVP) | :obj:`revised aspirin` | 100,000 |
|
|
83
83
|
+--------------------+--------------------+-------------------------------+-----------+
|
|
@@ -309,7 +309,7 @@ class MD17(InMemoryDataset):
|
|
|
309
309
|
file_names = {
|
|
310
310
|
'benzene': 'md17_benzene2017.npz',
|
|
311
311
|
'uracil': 'md17_uracil.npz',
|
|
312
|
-
'
|
|
312
|
+
'naphthalene': 'md17_naphthalene.npz',
|
|
313
313
|
'aspirin': 'md17_aspirin.npz',
|
|
314
314
|
'salicylic acid': 'md17_salicylic.npz',
|
|
315
315
|
'malonaldehyde': 'md17_malonaldehyde.npz',
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import os.path as osp
|
|
3
|
+
from typing import Callable, List, Optional
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from torch_geometric.data import Data, InMemoryDataset
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MedShapeNet(InMemoryDataset):
|
|
12
|
+
r"""The MedShapeNet datasets from the `"MedShapeNet -- A Large-Scale
|
|
13
|
+
Dataset of 3D Medical Shapes for Computer Vision"
|
|
14
|
+
<https://arxiv.org/abs/2308.16139>`_ paper,
|
|
15
|
+
containing 8 different types of structures (classes).
|
|
16
|
+
|
|
17
|
+
.. note::
|
|
18
|
+
|
|
19
|
+
Data objects hold mesh faces instead of edge indices.
|
|
20
|
+
To convert the mesh to a graph, use the
|
|
21
|
+
:obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.
|
|
22
|
+
To convert the mesh to a point cloud, use the
|
|
23
|
+
:obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to
|
|
24
|
+
sample a fixed number of points on the mesh faces according to their
|
|
25
|
+
face area.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
root (str): Root directory where the dataset should be saved.
|
|
29
|
+
size (int): Number of individual 3D structures to download per
|
|
30
|
+
type (classes).
|
|
31
|
+
transform (callable, optional): A function/transform that takes in an
|
|
32
|
+
:obj:`torch_geometric.data.Data` object and returns a transformed
|
|
33
|
+
version. The data object will be transformed before every access.
|
|
34
|
+
(default: :obj:`None`)
|
|
35
|
+
pre_transform (callable, optional): A function/transform that takes in
|
|
36
|
+
an :obj:`torch_geometric.data.Data` object and returns a
|
|
37
|
+
transformed version. The data object will be transformed before
|
|
38
|
+
being saved to disk. (default: :obj:`None`)
|
|
39
|
+
pre_filter (callable, optional): A function that takes in an
|
|
40
|
+
:obj:`torch_geometric.data.Data` object and returns a boolean
|
|
41
|
+
value, indicating whether the data object should be included in the
|
|
42
|
+
final dataset. (default: :obj:`None`)
|
|
43
|
+
force_reload (bool, optional): Whether to re-process the dataset.
|
|
44
|
+
(default: :obj:`False`)
|
|
45
|
+
"""
|
|
46
|
+
def __init__(
|
|
47
|
+
self,
|
|
48
|
+
root: str,
|
|
49
|
+
size: int = 100,
|
|
50
|
+
transform: Optional[Callable] = None,
|
|
51
|
+
pre_transform: Optional[Callable] = None,
|
|
52
|
+
pre_filter: Optional[Callable] = None,
|
|
53
|
+
force_reload: bool = False,
|
|
54
|
+
) -> None:
|
|
55
|
+
self.size = size
|
|
56
|
+
super().__init__(root, transform, pre_transform, pre_filter,
|
|
57
|
+
force_reload=force_reload)
|
|
58
|
+
|
|
59
|
+
path = self.processed_paths[0]
|
|
60
|
+
self.load(path)
|
|
61
|
+
|
|
62
|
+
@property
|
|
63
|
+
def raw_file_names(self) -> List[str]:
|
|
64
|
+
return [
|
|
65
|
+
'3DTeethSeg', 'CoronaryArteries', 'FLARE', 'KITS', 'PULMONARY',
|
|
66
|
+
'SurgicalInstruments', 'ThoracicAorta_Saitta', 'ToothFairy'
|
|
67
|
+
]
|
|
68
|
+
|
|
69
|
+
@property
|
|
70
|
+
def processed_file_names(self) -> List[str]:
|
|
71
|
+
return ['dataset.pt']
|
|
72
|
+
|
|
73
|
+
@property
|
|
74
|
+
def raw_paths(self) -> List[str]:
|
|
75
|
+
r"""The absolute filepaths that must be present in order to skip
|
|
76
|
+
downloading.
|
|
77
|
+
"""
|
|
78
|
+
return [osp.join(self.raw_dir, f) for f in self.raw_file_names]
|
|
79
|
+
|
|
80
|
+
def process(self) -> None:
|
|
81
|
+
import urllib3
|
|
82
|
+
from MedShapeNet import MedShapeNet as msn
|
|
83
|
+
|
|
84
|
+
msn_instance = msn(timeout=120)
|
|
85
|
+
|
|
86
|
+
urllib3.HTTPConnectionPool("medshapenet.ddns.net", maxsize=50)
|
|
87
|
+
|
|
88
|
+
list_of_datasets = msn_instance.datasets(False)
|
|
89
|
+
list_of_datasets = list(
|
|
90
|
+
filter(
|
|
91
|
+
lambda x: x not in [
|
|
92
|
+
'medshapenetcore/ASOCA', 'medshapenetcore/AVT',
|
|
93
|
+
'medshapenetcore/AutoImplantCraniotomy',
|
|
94
|
+
'medshapenetcore/FaceVR'
|
|
95
|
+
], list_of_datasets))
|
|
96
|
+
|
|
97
|
+
subset = []
|
|
98
|
+
for dataset in list_of_datasets:
|
|
99
|
+
parts = dataset.split("/")
|
|
100
|
+
self.newpath = self.root + '/' + parts[1 if len(parts) > 1 else 0]
|
|
101
|
+
if not os.path.exists(self.newpath):
|
|
102
|
+
os.makedirs(self.newpath)
|
|
103
|
+
stl_files = msn_instance.dataset_files(dataset, '.stl')
|
|
104
|
+
subset.extend(stl_files[:self.size])
|
|
105
|
+
|
|
106
|
+
for stl_file in stl_files[:self.size]:
|
|
107
|
+
msn_instance.download_stl_as_numpy(bucket_name=dataset,
|
|
108
|
+
stl_file=stl_file,
|
|
109
|
+
output_dir=self.newpath,
|
|
110
|
+
print_output=False)
|
|
111
|
+
|
|
112
|
+
class_mapping = {
|
|
113
|
+
'3DTeethSeg': 0,
|
|
114
|
+
'CoronaryArteries': 1,
|
|
115
|
+
'FLARE': 2,
|
|
116
|
+
'KITS': 3,
|
|
117
|
+
'PULMONARY': 4,
|
|
118
|
+
'SurgicalInstruments': 5,
|
|
119
|
+
'ThoracicAorta_Saitta': 6,
|
|
120
|
+
'ToothFairy': 7
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
for dataset, path in zip([subset], self.processed_paths):
|
|
124
|
+
data_list = []
|
|
125
|
+
for item in dataset:
|
|
126
|
+
class_name = item.split("/")[0]
|
|
127
|
+
item = item.split("stl")[0]
|
|
128
|
+
target = class_mapping[class_name]
|
|
129
|
+
file = osp.join(self.root, item + 'npz')
|
|
130
|
+
|
|
131
|
+
data = np.load(file)
|
|
132
|
+
pre_data_list = Data(
|
|
133
|
+
pos=torch.tensor(data["vertices"], dtype=torch.float),
|
|
134
|
+
face=torch.tensor(data["faces"],
|
|
135
|
+
dtype=torch.long).t().contiguous())
|
|
136
|
+
pre_data_list.y = torch.tensor([target], dtype=torch.long)
|
|
137
|
+
data_list.append(pre_data_list)
|
|
138
|
+
|
|
139
|
+
if self.pre_filter is not None:
|
|
140
|
+
data_list = [d for d in data_list if self.pre_filter(d)]
|
|
141
|
+
|
|
142
|
+
if self.pre_transform is not None:
|
|
143
|
+
data_list = [self.pre_transform(d) for d in data_list]
|
|
144
|
+
|
|
145
|
+
self.save(data_list, path)
|