atomic-datasets 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atomic_datasets/__init__.py +3 -0
- atomic_datasets/datasets/__init__.py +15 -0
- atomic_datasets/datasets/chembl3d.py +205 -0
- atomic_datasets/datasets/crossdocked.py +167 -0
- atomic_datasets/datasets/geom_drugs.py +209 -0
- atomic_datasets/datasets/platonic_solids.py +135 -0
- atomic_datasets/datasets/proteins.py +672 -0
- atomic_datasets/datasets/qm9.py +308 -0
- atomic_datasets/datasets/tetris.py +51 -0
- atomic_datasets/datasets/tmqm.py +475 -0
- atomic_datasets/datatypes.py +61 -0
- atomic_datasets/utils/__init__.py +14 -0
- atomic_datasets/utils/cache.py +95 -0
- atomic_datasets/utils/decorators.py +9 -0
- atomic_datasets/utils/download.py +99 -0
- atomic_datasets/utils/periodic_table.py +59 -0
- atomic_datasets/utils/rdkit.py +79 -0
- atomic_datasets/utils/visualizer.py +112 -0
- atomic_datasets/utils/xyz.py +24 -0
- atomic_datasets/wrappers/__init__.py +0 -0
- atomic_datasets/wrappers/jax.py +99 -0
- atomic_datasets/wrappers/torch.py +104 -0
- atomic_datasets-0.1.0.dist-info/METADATA +414 -0
- atomic_datasets-0.1.0.dist-info/RECORD +27 -0
- atomic_datasets-0.1.0.dist-info/WHEEL +5 -0
- atomic_datasets-0.1.0.dist-info/licenses/LICENSE +21 -0
- atomic_datasets-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .qm9 import QM9
|
|
2
|
+
from .geom_drugs import GEOMDrugs
|
|
3
|
+
from .tmqm import tmQM
|
|
4
|
+
from .platonic_solids import PlatonicSolids
|
|
5
|
+
from .tetris import Tetris
|
|
6
|
+
from .proteins import (
|
|
7
|
+
CATHAlphaCarbons,
|
|
8
|
+
MiniproteinsAlphaCarbons,
|
|
9
|
+
MiniproteinsBackbone,
|
|
10
|
+
MiniproteinsBackboneNoAA,
|
|
11
|
+
Miniproteins,
|
|
12
|
+
get_amino_acids,
|
|
13
|
+
)
|
|
14
|
+
from .chembl3d import ChEMBL3D
|
|
15
|
+
from .crossdocked import CrossDocked
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
import zipfile
|
|
4
|
+
from typing import Dict, Iterable, Optional, Sequence, Any
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from atomic_datasets import datatypes
|
|
9
|
+
from atomic_datasets import utils
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Zenodo URL for preprocessed data.
# This is the zip archive that ChEMBL3D._download_from_zenodo fetches and
# unpacks into the dataset directory when the split's files are not on disk.
ChEMBL3D_ZENODO_URL = "https://zenodo.org/records/18488050/files/chembl3d_processed.zip"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ChEMBL3D(datatypes.MolecularDataset):
    """
    Dataset of ChEMBL3D structures from https://github.com/isayevlab/LoQI.

    Contains structures for a large subset of ChEMBL, pre-optimized or extracted
    from 3D experimental data. This implementation uses a high-performance
    contiguous memory layout with memory-mapping support.

    The per-atom arrays of every molecule in a split are stored back to back;
    an offsets array gives, for molecule ``i``, the half-open atom range
    ``offsets[i]:offsets[i + 1]``. Data is loaded (and downloaded from Zenodo
    if missing) eagerly in the constructor.

    Args:
        root_dir: Directory to store/load data
        split: Which split to use ('train', 'val', 'test_small', 'test_rot_bonds', 'test_cremp')
        start_index: Start index for slicing the dataset
        end_index: End index for slicing the dataset
        train_on_single_molecule: If True, use single molecule for all splits
        train_on_single_molecule_index: Index of molecule to use if train_on_single_molecule=True
        mmap_mode: Memory-map mode for numpy arrays ('r', 'r+', 'c', or None to load into memory)

    Raises:
        FileNotFoundError: If the files for ``split`` are still missing after
            the archive has been downloaded and extracted.
        IndexError: If ``train_on_single_molecule`` is set and
            ``train_on_single_molecule_index`` is not a valid molecule index.
    """

    ATOMIC_NUMBERS = np.asarray([
        1, 5, 6, 7, 8, 9, 13, 14, 15, 16, 17, 33, 35, 53, 80, 83, 34
    ], dtype=np.int32)

    def __init__(
        self,
        root_dir: str,
        split: str,
        start_index: Optional[int] = None,
        end_index: Optional[int] = None,
        train_on_single_molecule: bool = False,
        train_on_single_molecule_index: int = 0,
        mmap_mode: Optional[str] = 'r',
    ):
        super().__init__(atomic_numbers=self.ATOMIC_NUMBERS)

        # All files for this dataset live in a dedicated sub-directory.
        self.root_dir = os.path.join(root_dir, "chembl3d")
        self.split = split
        self.start_index = start_index
        self.end_index = end_index
        self.train_on_single_molecule = train_on_single_molecule
        self.train_on_single_molecule_index = train_on_single_molecule_index
        self.mmap_mode = mmap_mode

        self.preprocessed = False

        # Data storage - will be initialized as contiguous arrays or memory-mapped views
        self._positions = None      # (total_atoms, 3) float32
        self._species = None        # (total_atoms,) uint8
        self._offsets = None        # (n_molecules + 1,) int64
        self._charges = None        # (total_atoms,) int8
        self._is_aromatic = None    # (total_atoms,) uint8
        self._is_in_ring = None     # (total_atoms,) uint8
        self._hybridization = None  # (total_atoms,) uint8
        self._smiles = None         # list of str
        self._chemblids = None      # list of str
        self._indices = None        # indices after filtering/slicing

        self.preprocess()

    def _get_file_path(self, filename: str) -> str:
        """Helper to get the absolute path to a preprocessed file."""
        return os.path.join(self.root_dir, filename)

    def _download_from_zenodo(self):
        """Download the preprocessed archive from Zenodo and unpack it into ``root_dir``.

        Every file in the archive is written directly into ``root_dir`` under
        its base name, i.e. any directory structure inside the zip is
        flattened. The zip file is deleted afterwards.
        """
        # Local import: only needed on the download path.
        import shutil

        os.makedirs(self.root_dir, exist_ok=True)

        zip_filename = "chembl3d_processed.zip"
        zip_path = os.path.join(self.root_dir, zip_filename)

        # Download
        print("Downloading ChEMBL3D from Zenodo...")
        print(f" URL: {ChEMBL3D_ZENODO_URL}")
        # Pass the filename explicitly so the downloaded file is guaranteed to
        # end up at zip_path instead of relying on the helper's default name.
        utils.download_url(ChEMBL3D_ZENODO_URL, self.root_dir, filename=zip_filename)

        # Extract
        print(f"Extracting to {self.root_dir}...")
        with zipfile.ZipFile(zip_path, 'r') as zf:
            # Extract and flatten directory structure
            for member in zf.namelist():
                # Skip directories
                if member.endswith('/'):
                    continue

                # Get the filename without the parent directory. Using only the
                # base name also keeps every extracted file inside root_dir.
                filename = os.path.basename(member)
                if not filename:
                    continue

                # Stream the member to disk in chunks: the arrays can be large,
                # so avoid holding a whole decompressed file in memory, and
                # make sure both handles are closed even if the copy fails.
                target_path = os.path.join(self.root_dir, filename)
                with zf.open(member) as source, open(target_path, 'wb') as target:
                    shutil.copyfileobj(source, target)

        # Clean up zip file to save disk space
        os.remove(zip_path)
        print("Download complete.")

    def preprocess(self):
        """Load preprocessed numpy files and setup indices.

        Downloads the archive first when the split's positions file is not
        present. Calling this again after a successful run is a no-op.

        Raises:
            FileNotFoundError: If the split's positions file is still missing
                after downloading (for example because of an unknown split name).
            IndexError: If single-molecule mode is requested with an index
                outside ``[0, n_molecules)``.
        """
        if self.preprocessed:
            return

        prefix = self.split

        # Check that files exist, download if needed
        positions_path = self._get_file_path(f"{prefix}_positions.npy")
        if not os.path.exists(positions_path):
            self._download_from_zenodo()
            if not os.path.exists(positions_path):
                # Fail with an actionable message rather than a bare np.load error.
                raise FileNotFoundError(
                    f"ChEMBL3D files for split '{self.split}' not found in "
                    f"{self.root_dir} (expected {positions_path}). "
                    "Check that the split name is valid."
                )

        print(f"Loading ChEMBL3D {self.split} split from: {self.root_dir}")

        # Load arrays (memory-mapped for large files to save RAM)
        self._positions = np.load(positions_path, mmap_mode=self.mmap_mode)
        self._species = np.load(self._get_file_path(f"{prefix}_species.npy"), mmap_mode=self.mmap_mode)
        self._offsets = np.load(self._get_file_path(f"{prefix}_offsets.npy"))  # Small, load fully into RAM
        self._charges = np.load(self._get_file_path(f"{prefix}_charges.npy"), mmap_mode=self.mmap_mode)
        self._is_aromatic = np.load(self._get_file_path(f"{prefix}_is_aromatic.npy"), mmap_mode=self.mmap_mode)
        self._is_in_ring = np.load(self._get_file_path(f"{prefix}_is_in_ring.npy"), mmap_mode=self.mmap_mode)
        self._hybridization = np.load(self._get_file_path(f"{prefix}_hybridization.npy"), mmap_mode=self.mmap_mode)

        # Load metadata (SMILES and IDs)
        metadata_path = self._get_file_path(f"{prefix}_metadata.json")
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
        self._smiles = metadata["smiles"]
        self._chemblids = metadata["chemblids"]

        n_molecules = len(self._offsets) - 1
        n_atoms = len(self._positions)
        print(f"Loaded {n_molecules:,} molecules, {n_atoms:,} total atoms")

        # Apply single molecule override for debugging or overfitting tests
        if self.train_on_single_molecule:
            mol_index = self.train_on_single_molecule_index
            # Validate here: an out-of-range index would only fail later in
            # __getitem__, and a negative one would silently pair the wrong
            # atom range with the wrong metadata entry.
            if not 0 <= mol_index < n_molecules:
                raise IndexError(
                    f"train_on_single_molecule_index {mol_index} out of range "
                    f"for split with {n_molecules} molecules"
                )
            self._indices = np.array([mol_index])
        else:
            self._indices = np.arange(n_molecules)

        # Apply start/end index for subsetting
        self._indices = self._indices[slice(self.start_index, self.end_index)]

        self.preprocessed = True

    def __len__(self) -> int:
        """Returns the number of molecules in the current split/slice."""
        return len(self._indices)

    def __getitem__(self, idx: int) -> datatypes.Graph:
        """
        Fast random access to a molecule using offset-based slicing.

        Args:
            idx: Position within the current (sliced) dataset; negative values
                count from the end.

        Returns:
            A datatypes.Graph object containing positions, species, and atomic properties.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_items = len(self._indices)
        position = idx + n_items if idx < 0 else idx
        if position < 0 or position >= n_items:
            raise IndexError(f"Index {idx} out of range for dataset of size {n_items}")

        # Translate the dataset position into the molecule's atom range.
        mol_idx = self._indices[position]
        start = self._offsets[mol_idx]
        end = self._offsets[mol_idx + 1]

        # Slice arrays (copies from mmap into RAM for the return object)
        species = np.array(self._species[start:end])
        atomic_numbers = self.species_to_atomic_numbers(species)
        atom_types = utils.atomic_numbers_to_symbols(atomic_numbers)

        return datatypes.Graph(
            nodes=dict(
                positions=np.array(self._positions[start:end]),
                atomic_numbers=atomic_numbers,
                species=species,
                atom_types=atom_types,
            ),
            edges=None,
            receivers=None,
            senders=None,
            globals=None,
            n_node=np.asarray([end - start]),
            n_edge=None,
            properties={
                "smiles": self._smiles[mol_idx],
                "chemblid": self._chemblids[mol_idx],
                "charges": np.array(self._charges[start:end]),
                "is_aromatic": np.array(self._is_aromatic[start:end]),
                "is_in_ring": np.array(self._is_in_ring[start:end]),
                "hybridization": np.array(self._hybridization[start:end]),
            },
        )
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
from typing import Iterable, Dict, Optional, Tuple
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import logging
|
|
5
|
+
import json
|
|
6
|
+
import zipfile
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from atomic_datasets import utils
|
|
11
|
+
from atomic_datasets import datatypes
|
|
12
|
+
|
|
13
|
+
# Zenodo URL for preprocessed data.
# This is the zip archive that CrossDocked._ensure_downloaded fetches and
# extracts when the extracted marker file is not found on disk.
CROSSDOCKED_ZENODO_URL = "https://zenodo.org/records/18584578/files/crossdocked.zip"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class CrossDocked(datatypes.MolecularDataset):
    """
    The CrossDocked dataset from https://pubs.acs.org/doi/full/10.1021/acs.jcim.0c00411
    with splits from Luo et al. (https://proceedings.neurips.cc/paper/2021/hash/314450613369e0ee72d0da7f6fee773c-Abstract.html).

    Loads preprocessed data from Zenodo (memory-mapped for efficiency).

    The per-atom arrays of every complex in a split are stored back to back;
    an offsets array gives, for complex ``i``, the half-open atom range
    ``offsets[i]:offsets[i + 1]``. Data is loaded (and downloaded if missing)
    eagerly in the constructor.

    Args:
        root_dir: Directory to store/load data
        split: Which split to use ('train', 'val', 'test')
        start_index: Start index for slicing the dataset
        end_index: End index for slicing the dataset
        mmap_mode: Memory-map mode for numpy arrays ('r', 'r+', 'c', or None to load into memory)

    Raises:
        ValueError: If ``split`` is not one of 'train', 'val', 'test'.
    """

    ATOMIC_NUMBERS = np.asarray([
        13, 33, 79, 5, 35, 6, 17, 27, 24, 29, 9, 26, 1, 80, 53, 3, 12,
        42, 7, 8, 15, 44, 16, 21, 34, 14, 50, 23, 74, 39
    ], dtype=np.int32)

    def __init__(
        self,
        root_dir: str,
        split: str = "train",
        start_index: Optional[int] = None,
        end_index: Optional[int] = None,
        mmap_mode: Optional[str] = 'r',
    ):
        super().__init__(atomic_numbers=self.ATOMIC_NUMBERS)

        self.root_dir = os.path.join(root_dir, "crossdocked")
        self.split = split
        self.start_index = start_index
        self.end_index = end_index
        self.mmap_mode = mmap_mode

        self.preprocessed = False

        # Data storage
        self._positions = None  # (N_total, 3) memory-mapped
        self._atom_types = None  # (N_total,) memory-mapped (indices into lookup)
        self._offsets = None  # (n_complexes + 1,) start indices
        self._n_atoms = None  # (n_complexes,) atoms per complex
        self._atom_type_lookup = None  # (n_types,) symbol strings
        self._properties = None  # List of dicts (pocket_file, starting_fragment_mask)
        self._indices = None  # Indices after slicing

        if split not in ("train", "val", "test"):
            raise ValueError(f"split must be 'train', 'val', or 'test', got '{split}'")

        self.preprocess()

        # Point the user at the dataset description if one is present in root_dir.
        readme_path = os.path.join(self.root_dir, "README.md")
        if os.path.exists(readme_path):
            print("Dataset description available at:", os.path.abspath(readme_path))

    def _get_file_path(self, filename: str) -> str:
        """Return the path of a preprocessed file.

        The data files are read from a ``crossdocked`` sub-directory of
        ``root_dir`` (presumably the top-level folder of the extracted
        archive - confirm against the archive layout).
        """
        return os.path.join(self.root_dir, "crossdocked", filename)

    def preprocess(self):
        """Initialize data access - downloads if needed, then loads.

        Calling this again after a successful run is a no-op.
        """
        if self.preprocessed:
            return

        self._ensure_downloaded()
        self._load_data()
        self._setup_indices()

        self.preprocessed = True

    def _ensure_downloaded(self):
        """Download and extract preprocessed files from Zenodo if not present.

        The presence of the train split's positions file is used as the marker
        that extraction has already happened, regardless of the requested
        split. The zip file is deleted after extraction.
        """
        os.makedirs(self.root_dir, exist_ok=True)

        # Check if data is already extracted
        marker_file = self._get_file_path("train_positions.npy")
        if os.path.exists(marker_file):
            return

        zip_filename = "crossdocked.zip"
        zip_path = os.path.join(self.root_dir, zip_filename)

        # Re-use an archive that is already on disk instead of downloading again.
        if not os.path.exists(zip_path):
            utils.download_url(CROSSDOCKED_ZENODO_URL, self.root_dir, filename=zip_filename)

        print(f"Extracting {zip_filename}...")
        with zipfile.ZipFile(zip_path, 'r') as zf:
            zf.extractall(self.root_dir)

        os.remove(zip_path)
        print("Extraction complete.")

    def _load_data(self):
        """Load preprocessed data using memory mapping.

        Only the two per-atom arrays (positions and atom types) honour
        ``mmap_mode``; the per-complex arrays, the lookup table and the JSON
        properties are read fully into memory.
        """
        prefix = self.split
        print(f"Loading CrossDocked {self.split} split from {self.root_dir}")

        self._positions = np.load(
            self._get_file_path(f"{prefix}_positions.npy"),
            mmap_mode=self.mmap_mode,
        )
        self._atom_types = np.load(
            self._get_file_path(f"{prefix}_atom_types.npy"),
            mmap_mode=self.mmap_mode,
        )

        self._offsets = np.load(self._get_file_path(f"{prefix}_offsets.npy"))
        self._n_atoms = np.load(self._get_file_path(f"{prefix}_n_atoms.npy"))
        self._atom_type_lookup = np.load(
            self._get_file_path(f"{prefix}_atom_type_lookup.npy")
        )

        with open(self._get_file_path(f"{prefix}_properties.json")) as f:
            self._properties = json.load(f)

        n_complexes = len(self._n_atoms)
        print(f"Loaded {n_complexes} complexes")

    def _setup_indices(self):
        """Setup indices with optional slicing."""
        n_complexes = len(self._n_atoms)
        self._indices = np.arange(n_complexes)
        self._indices = self._indices[slice(self.start_index, self.end_index)]

    def __len__(self) -> int:
        """Return the number of complexes in the current split/slice."""
        return len(self._indices)

    def __getitem__(self, idx: int) -> datatypes.Graph:
        """Fast slice access via memory-mapped offsets.

        Args:
            idx: Position within the current (sliced) dataset; negative values
                count from the end.

        Returns:
            A datatypes.Graph whose nodes hold positions, atomic numbers,
            species indices and atom-type symbols of one complex, and whose
            ``properties`` is the stored property entry of that complex.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_items = len(self._indices)
        position = idx + n_items if idx < 0 else idx
        # Without this check an index below -len(self) would stay negative
        # after the adjustment above and silently wrap around a second time
        # inside numpy, returning an unrelated complex.
        if position < 0 or position >= n_items:
            raise IndexError(f"Index {idx} out of range for dataset of size {n_items}")

        real_idx = self._indices[position]
        start, end = self._offsets[real_idx], self._offsets[real_idx + 1]

        # Stored atom types are indices into the lookup table of symbols.
        species = np.array(self._atom_types[start:end])
        atom_types = self._atom_type_lookup[species]
        atomic_numbers = utils.atomic_symbols_to_numbers(atom_types)

        return datatypes.Graph(
            nodes=dict(
                positions=np.array(self._positions[start:end]),
                atomic_numbers=atomic_numbers,
                species=species,
                atom_types=atom_types,
            ),
            edges=None,
            senders=None,
            receivers=None,
            n_edge=None,
            n_node=np.asarray([self._n_atoms[real_idx]]),
            globals=None,
            properties=self._properties[real_idx],
        )
|
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
from typing import Dict, Iterable, Optional, List, Tuple, Union
|
|
2
|
+
import os
|
|
3
|
+
import logging
|
|
4
|
+
import json
|
|
5
|
+
import zipfile
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from atomic_datasets import datatypes
|
|
10
|
+
from atomic_datasets import utils
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Zenodo URL for preprocessed data.
# This is the zip archive that GEOMDrugs._ensure_downloaded fetches and
# extracts when the extracted marker file is not found on disk.
# NOTE(review): this URL uses the "/record/" path segment while the other
# dataset URLs in this package use "/records/" - presumably this relies on a
# server-side redirect; confirm and align if so.
GEOM_DRUGS_ZENODO_URL = "https://zenodo.org/record/18484634/files/geom_drugs_processed.zip"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class GEOMDrugs(datatypes.MolecularDataset):
    """
    The GEOM (Drugs) dataset from https://www.nature.com/articles/s41597-022-01288-4.

    Loads preprocessed data from Zenodo (memory-mapped for efficiency).

    The per-atom arrays of every conformer in a split are stored back to back;
    an offsets array gives, for conformer ``i``, the half-open atom range
    ``offsets[i]:offsets[i + 1]``. Data is loaded (and downloaded if missing)
    eagerly in the constructor.

    Args:
        root_dir: Directory to store/load data
        split: Which split to use ('train', 'val', 'test')
        start_index: Start index for slicing the dataset
        end_index: End index for slicing the dataset
        max_atoms: Filter out molecules with more atoms than this
        conformer_selection: How to select conformers ('first', 'random', 'all')
        random_seed: Random seed for conformer selection (if conformer_selection='random')
        mmap_mode: Memory-map mode for numpy arrays ('r', 'r+', 'c', or None to load into memory)

    Raises:
        ValueError: If ``split`` or ``conformer_selection`` is not one of the
            accepted values.
    """

    # Atomic numbers present in GEOM-Drugs
    # NOTE(review): no explicit dtype here, so the integer width is the
    # platform default; left as-is to keep the output dtype unchanged.
    ATOMIC_NUMBERS = np.asarray([1, 5, 6, 7, 8, 9, 13, 14, 15, 16, 17, 33, 35, 53, 80, 83])

    def __init__(
        self,
        root_dir: str,
        split: str = "train",
        start_index: Optional[int] = None,
        end_index: Optional[int] = None,
        max_atoms: Optional[int] = None,
        conformer_selection: str = "all",
        random_seed: int = 0,
        mmap_mode: Optional[str] = 'r',
    ):
        # Initialize the base class mapping logic
        super().__init__(atomic_numbers=self.ATOMIC_NUMBERS)

        self.root_dir = os.path.join(root_dir, "geom_drugs")
        self.split = split
        self.start_index = start_index
        self.end_index = end_index
        self.max_atoms = max_atoms
        self.conformer_selection = conformer_selection
        self.random_seed = random_seed
        self.mmap_mode = mmap_mode

        self.preprocessed = False

        # Data storage
        self._positions = None  # (N_total, 3) memory-mapped
        self._species = None  # (N_total,) memory-mapped species indices
        # Not read anywhere in this class; kept so the attribute still exists
        # for any external code that looks it up.
        self._atomic_numbers = None
        self._offsets = None  # (n_conformers + 1,) start indices
        self._n_atoms = None  # (n_conformers,) atoms per conformer
        self._mol_indices = None  # (n_conformers,) molecule index
        self._smiles = None  # List of SMILES strings
        self._indices = None  # Indices into conformers (after filtering)
        self._rng = None

        if split not in ("train", "val", "test"):
            raise ValueError(f"split must be 'train', 'val', or 'test', got '{split}'")

        if conformer_selection not in ("first", "random", "all"):
            raise ValueError(f"conformer_selection must be 'first', 'random', or 'all', got '{conformer_selection}'")

        self.preprocess()

    def preprocess(self):
        """Initialize data access - downloads if needed, then loads.

        Calling this again after a successful run is a no-op.
        """
        if self.preprocessed:
            return

        # Download and extract if needed
        self._ensure_downloaded()

        # Load data
        self._load_data()

        # Setup indices based on conformer selection
        self._rng = np.random.default_rng(self.random_seed)
        self._setup_indices()

        self.preprocessed = True

    def _ensure_downloaded(self):
        """Download and extract preprocessed files from Zenodo if not present.

        The presence of the train split's positions file is used as the marker
        that extraction has already happened, regardless of the requested
        split. The zip file is deleted after extraction.
        """
        os.makedirs(self.root_dir, exist_ok=True)

        # Check if data is already extracted
        marker_file = os.path.join(self.root_dir, "train_positions.npy")
        if os.path.exists(marker_file):
            return

        zip_filename = "geom_drugs_processed.zip"
        zip_path = os.path.join(self.root_dir, zip_filename)

        # Re-use an archive that is already on disk instead of downloading again.
        if not os.path.exists(zip_path):
            print(f"Downloading {zip_filename}...")
            utils.download_url(GEOM_DRUGS_ZENODO_URL, self.root_dir, filename=zip_filename)

        print(f"Extracting {zip_filename}...")
        with zipfile.ZipFile(zip_path, 'r') as zf:
            zf.extractall(self.root_dir)

        os.remove(zip_path)
        print("Extraction complete.")

    def _load_data(self):
        """Load preprocessed data using memory mapping.

        Only the two per-atom arrays (positions and species) honour
        ``mmap_mode``; the per-conformer arrays and the SMILES list are read
        fully into memory.
        """
        prefix = self.split
        print(f"Loading GEOM-Drugs {self.split} split from {self.root_dir}")

        self._positions = np.load(
            os.path.join(self.root_dir, f"{prefix}_positions.npy"),
            mmap_mode=self.mmap_mode
        )
        self._species = np.load(
            os.path.join(self.root_dir, f"{prefix}_species.npy"),
            mmap_mode=self.mmap_mode
        )

        # Regular arrays for metadata/indexing
        self._offsets = np.load(os.path.join(self.root_dir, f"{prefix}_offsets.npy"))
        self._n_atoms = np.load(os.path.join(self.root_dir, f"{prefix}_n_atoms.npy"))
        self._mol_indices = np.load(os.path.join(self.root_dir, f"{prefix}_mol_indices.npy"))

        with open(os.path.join(self.root_dir, f"{prefix}_smiles.json")) as f:
            self._smiles = json.load(f)

        n_molecules = len(np.unique(self._mol_indices))
        n_conformers = len(self._n_atoms)
        print(f"Loaded {n_molecules} molecules with {n_conformers} total conformers")

    def _setup_indices(self):
        """Setup indices based on conformer selection mode and dataset slicing.

        Order of operations: pick conformers ('all', 'first' per molecule, or
        one 'random' per molecule), then drop conformers with more than
        ``max_atoms`` atoms, then apply the ``start_index``/``end_index`` slice.
        """
        n_conformers = len(self._n_atoms)

        if self.conformer_selection == "all":
            self._indices = np.arange(n_conformers)
        else:
            # Per molecule: index of its first conformer and how many it has.
            # Counts come straight from np.unique so they do not depend on the
            # order in which molecule ids appear in the array.
            _, first_indices, counts = np.unique(
                self._mol_indices, return_index=True, return_counts=True
            )

            if self.conformer_selection == "first":
                self._indices = first_indices
            elif self.conformer_selection == "random":
                # Assumes the conformers of one molecule are stored
                # contiguously, so first index + offset stays inside that
                # molecule - TODO confirm against the preprocessing script.
                # Drawn one by one (not vectorized) to keep the sequence of
                # random draws for a given seed unchanged; the explicit integer
                # dtype keeps the result usable as an index even when empty.
                picks = np.fromiter(
                    (self._rng.integers(0, c) for c in counts),
                    dtype=first_indices.dtype,
                    count=len(counts),
                )
                self._indices = first_indices + picks

        # Filter by atom count
        if self.max_atoms is not None:
            mask = self._n_atoms[self._indices] <= self.max_atoms
            self._indices = self._indices[mask]

        # Apply start/end slicing
        self._indices = self._indices[slice(self.start_index, self.end_index)]

    def __len__(self) -> int:
        """Return the number of selected conformers in the current split/slice."""
        return len(self._indices)

    def __getitem__(self, idx: int) -> datatypes.Graph:
        """Fast slice access via memory-mapped offsets.

        Args:
            idx: Position within the current (filtered/sliced) dataset;
                negative values count from the end.

        Returns:
            A datatypes.Graph whose nodes hold positions, atomic numbers,
            species indices and atom-type symbols of one conformer, with the
            SMILES string in ``properties``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_items = len(self._indices)
        position = idx + n_items if idx < 0 else idx
        # Without this check an index below -len(self) would stay negative
        # after the adjustment above and silently wrap around a second time
        # inside numpy, returning an unrelated conformer.
        if position < 0 or position >= n_items:
            raise IndexError(f"Index {idx} out of range for dataset of size {n_items}")

        real_idx = self._indices[position]
        start, end = self._offsets[real_idx], self._offsets[real_idx + 1]

        # Extract data for this conformer (np.array copies out of the mmap;
        # positions are copied exactly once).
        positions = np.array(self._positions[start:end])
        species = np.array(self._species[start:end])
        atomic_numbers = self.species_to_atomic_numbers(species)
        atom_types = utils.atomic_numbers_to_symbols(atomic_numbers)

        return datatypes.Graph(
            nodes=dict(
                positions=positions,
                atomic_numbers=atomic_numbers,
                species=species,
                atom_types=atom_types,
            ),
            edges=None,
            senders=None,
            receivers=None,
            n_edge=None,
            n_node=np.asarray([self._n_atoms[real_idx]]),
            globals=None,
            # NOTE(review): SMILES is looked up by conformer index, which
            # presumes the list has one entry per conformer - confirm.
            properties=dict(smiles=self._smiles[real_idx]),
        )

    def get_num_conformers(self, mol_idx: int) -> int:
        """Get number of conformers for a molecule by its molecule index.

        Counts over all conformers loaded for the split, not only the ones
        selected by ``conformer_selection``/``max_atoms``/slicing. Returns a
        plain Python ``int``.
        """
        return int(np.count_nonzero(self._mol_indices == mol_idx))

    def get_molecule_indices(self) -> np.ndarray:
        """Get array mapping each loaded conformer back to its molecule index."""
        return self._mol_indices[self._indices]
|