bblean 0.6.0b1__cp312-cp312-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bblean/__init__.py +22 -0
- bblean/_config.py +61 -0
- bblean/_console.py +187 -0
- bblean/_cpp_similarity.cpython-312-darwin.so +0 -0
- bblean/_legacy/__init__.py +0 -0
- bblean/_legacy/bb_int64.py +1252 -0
- bblean/_legacy/bb_uint8.py +1144 -0
- bblean/_memory.py +198 -0
- bblean/_merges.py +212 -0
- bblean/_py_similarity.py +278 -0
- bblean/_timer.py +42 -0
- bblean/_version.py +34 -0
- bblean/analysis.py +258 -0
- bblean/bitbirch.py +1437 -0
- bblean/cli.py +1854 -0
- bblean/csrc/README.md +1 -0
- bblean/csrc/similarity.cpp +521 -0
- bblean/fingerprints.py +424 -0
- bblean/metrics.py +199 -0
- bblean/multiround.py +489 -0
- bblean/plotting.py +479 -0
- bblean/similarity.py +304 -0
- bblean/sklearn.py +203 -0
- bblean/smiles.py +61 -0
- bblean/utils.py +130 -0
- bblean-0.6.0b1.dist-info/METADATA +283 -0
- bblean-0.6.0b1.dist-info/RECORD +31 -0
- bblean-0.6.0b1.dist-info/WHEEL +6 -0
- bblean-0.6.0b1.dist-info/entry_points.txt +2 -0
- bblean-0.6.0b1.dist-info/licenses/LICENSE +48 -0
- bblean-0.6.0b1.dist-info/top_level.txt +1 -0
bblean/fingerprints.py
ADDED
@@ -0,0 +1,424 @@
r"""Utilities for manipulating fingerprints and fingerprint files"""

import warnings
import dataclasses
from pathlib import Path
from numpy.typing import NDArray, DTypeLike
import numpy as np
import typing as tp
import multiprocessing.shared_memory as shmem

from rich.console import Console
from rdkit.Chem import rdFingerprintGenerator, MolFromSmiles, SanitizeFlags, SanitizeMol

from bblean._config import DEFAULTS
from bblean._console import get_console

__all__ = [
    "make_fake_fingerprints",
    "fps_from_smiles",
    "pack_fingerprints",
    "unpack_fingerprints",
]


# Deprecated
def calc_centroid(
    linear_sum: NDArray[np.integer], n_samples: int, *, pack: bool = True
) -> NDArray[np.uint8]:
    warnings.warn(
        "Please use `bblean.similarity.centroid_from_sum(...)` instead",
        DeprecationWarning,
        stacklevel=2,
    )
    if n_samples <= 1:
        centroid = linear_sum.astype(np.uint8, copy=False)
    else:
        centroid = (linear_sum >= n_samples * 0.5).view(np.uint8)
    if pack:
        return np.packbits(centroid, axis=-1)
    return centroid


centroid_from_sum = calc_centroid

def pack_fingerprints(a: NDArray[np.uint8]) -> NDArray[np.uint8]:
    r"""Pack binary (only 0s and 1s) uint8 fingerprint arrays"""
    # packbits may pad with zeros if n_features is not a multiple of 8
    return np.packbits(a, axis=-1)


def unpack_fingerprints(
    a: NDArray[np.uint8], n_features: int | None = None
) -> NDArray[np.uint8]:
    r"""Unpack packed uint8 arrays into binary uint8 arrays (with only 0s and 1s)

    .. note::

        If ``n_features`` is not passed, unpacking will only recover the correct number
        of features if it is a multiple of 8, otherwise fingerprints will be padded
        with zeros to the closest multiple of 8. This is generally not an issue since
        most common fingerprint feature sizes (2048, 1024, etc.) are multiples of 8,
        but if you are using a non-standard number of features you should pass
        ``n_features`` explicitly.
    """
    # n_features is required to discard padded zeros if it is not a multiple of 8
    return np.unpackbits(a, axis=-1, count=n_features)

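A quick editorial sketch (not part of the package source) of the round trip these two helpers provide when the feature count is a multiple of 8:

from bblean.fingerprints import pack_fingerprints, unpack_fingerprints
import numpy as np

fps = np.random.default_rng(0).integers(0, 2, size=(4, 2048), dtype=np.uint8)
packed = pack_fingerprints(fps)  # shape (4, 256): 8 bits per byte
restored = unpack_fingerprints(packed, n_features=2048)
assert np.array_equal(fps, restored)  # exact round trip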
def make_fake_fingerprints(
    num: int,
    n_features: int = DEFAULTS.n_features,
    pack: bool = True,
    seed: int | None = None,
    dtype: DTypeLike = np.uint8,
) -> NDArray[np.uint8]:
    r"""Make random fingerprints with statistics similar to (some) real databases"""
    import scipy.stats  # Hide this import since scipy is heavy

    if n_features < 1 or n_features % 8 != 0:
        raise ValueError("n_features must be a multiple of 8, and greater than 0")
    # Generate "synthetic" fingerprints with a popcount distribution
    # similar to one in a real smiles database
    # Fps are guaranteed to *not* be all zeros or all ones
    if pack:
        if np.dtype(dtype) != np.dtype(np.uint8):
            raise ValueError("Only np.uint8 dtype is supported for packed input")
    loc = 750
    scale = 400
    bounds = (0, n_features)
    rng = np.random.default_rng(seed)
    safe_bounds = (bounds[0] + 1, bounds[1] - 1)
    a = (safe_bounds[0] - loc) / scale
    b = (safe_bounds[1] - loc) / scale
    popcounts_fake_float = scipy.stats.truncnorm.rvs(
        a, b, loc=loc, scale=scale, size=num, random_state=rng
    )
    popcounts_fake = np.rint(popcounts_fake_float).astype(np.int64)
    zerocounts_fake = n_features - popcounts_fake
    repeats_fake = np.empty((num * 2), dtype=np.int64)
    repeats_fake[0::2] = popcounts_fake
    repeats_fake[1::2] = zerocounts_fake
    initial = np.tile(np.array([1, 0], np.uint8), num)
    expanded = np.repeat(initial, repeats=repeats_fake)
    fps_fake = rng.permuted(expanded.reshape(num, n_features), axis=-1)
    if pack:
        return np.packbits(fps_fake, axis=1)
    return fps_fake.astype(dtype, copy=False)

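Editorial sketch of a typical call (seeded for reproducibility; the shape comment assumes the 2048 features requested here):

from bblean.fingerprints import make_fake_fingerprints, unpack_fingerprints

fake = make_fake_fingerprints(1000, n_features=2048, pack=True, seed=42)
print(fake.shape, fake.dtype)  # (1000, 256) uint8, i.e. 2048 packed bits per row
# Per-fingerprint popcounts follow the truncated normal (loc=750, scale=400)
popcounts = unpack_fingerprints(fake, n_features=2048).sum(axis=1)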
def _get_generator(kind: str, n_features: int) -> tp.Any:
    if kind == "rdkit":
        return rdFingerprintGenerator.GetRDKitFPGenerator(fpSize=n_features)
    elif kind == "ecfp4":
        return rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=n_features)
    elif kind == "ecfp6":
        return rdFingerprintGenerator.GetMorganGenerator(radius=3, fpSize=n_features)
    raise ValueError(f"Unknown kind {kind}. Should be one of 'rdkit|ecfp4|ecfp6'")


def _get_sanitize_flags(sanitize: str) -> tp.Any:
    if sanitize == "all":
        return SanitizeFlags.SANITIZE_ALL
    elif sanitize == "minimal":
        flags = SanitizeFlags.SANITIZE_CLEANUP | SanitizeFlags.SANITIZE_SYMMRINGS
        return flags
    else:
        raise ValueError("Unknown 'sanitize', must be one of 'all', 'minimal'")

@tp.overload
def fps_from_smiles(
    smiles: tp.Iterable[str],
    kind: str = DEFAULTS.fp_kind,
    n_features: int = DEFAULTS.n_features,
    dtype: DTypeLike = np.uint8,
    sanitize: str = "all",
    skip_invalid: tp.Literal[False] = False,
    pack: bool = True,
) -> NDArray[np.uint8]:
    pass


@tp.overload
def fps_from_smiles(
    smiles: tp.Iterable[str],
    kind: str = DEFAULTS.fp_kind,
    n_features: int = DEFAULTS.n_features,
    dtype: DTypeLike = np.uint8,
    sanitize: str = "all",
    skip_invalid: tp.Literal[True] = True,
    pack: bool = True,
) -> tuple[NDArray[np.uint8], NDArray[np.int64]]:
    pass


def fps_from_smiles(
    smiles: tp.Iterable[str],
    kind: str = DEFAULTS.fp_kind,
    n_features: int = DEFAULTS.n_features,
    dtype: DTypeLike = np.uint8,
    sanitize: str = "all",
    skip_invalid: bool = False,
    pack: bool = True,
) -> tp.Union[NDArray[np.uint8], tuple[NDArray[np.uint8], NDArray[np.int64]]]:
    r"""Convert a sequence of smiles into chemical fingerprints"""
    if n_features < 1 or n_features % 8 != 0:
        raise ValueError("n_features must be a multiple of 8, and greater than 0")
    if isinstance(smiles, str):
        smiles = [smiles]

    if pack and np.dtype(dtype) != np.dtype(np.uint8):
        raise ValueError("Packing only supported for uint8 dtype")

    fpg = _get_generator(kind, n_features)

    sanitize_flags = _get_sanitize_flags(sanitize)

    smiles = list(smiles)
    fps = np.empty((len(smiles), n_features), dtype=dtype)

    invalid_idxs = []
    for i, smi in enumerate(smiles):
        mol = MolFromSmiles(smi, sanitize=False)
        if mol is None:
            if skip_invalid:
                invalid_idxs.append(i)
                continue
            else:
                raise ValueError(f"Unable to parse smiles {smi}")
        try:
            SanitizeMol(mol, sanitizeOps=sanitize_flags)
            fps[i, :] = fpg.GetFingerprintAsNumPy(mol)
        except Exception:
            if skip_invalid:
                invalid_idxs.append(i)
                continue
            raise

    if invalid_idxs:
        fps = np.delete(fps, invalid_idxs, axis=0)
    if pack:
        if skip_invalid:
            return pack_fingerprints(fps), np.array(invalid_idxs, dtype=np.int64)
        return pack_fingerprints(fps)
    if skip_invalid:
        return fps, np.array(invalid_idxs, dtype=np.int64)
    return fps

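Editorial sketch of the two overloads (the invalid SMILES below is a deliberate parse failure; with ``skip_invalid=True`` the indices of skipped entries are returned alongside the fingerprints):

from bblean.fingerprints import fps_from_smiles

packed = fps_from_smiles(["CCO", "c1ccccc1"], kind="ecfp4", n_features=2048)
fps, bad = fps_from_smiles(["CCO", "not_a_smiles", "c1ccccc1"], skip_invalid=True)
print(fps.shape)  # (2, DEFAULTS.n_features // 8): the invalid entry was dropped
print(bad)        # [1]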
def _get_fps_file_num(path: Path) -> int:
    with open(path, mode="rb") as f:
        major, minor = np.lib.format.read_magic(f)
        shape, _, _ = getattr(np.lib.format, f"read_array_header_{major}_{minor}")(f)
    return shape[0]


def _get_fps_file_shape_and_dtype(
    path: Path, raise_if_invalid: bool = False
) -> tuple[tuple[int, int], np.dtype, bool, bool]:
    with open(path, mode="rb") as f:
        major, minor = np.lib.format.read_magic(f)
        shape, _, dtype = getattr(np.lib.format, f"read_array_header_{major}_{minor}")(
            f
        )
    shape_is_valid = len(shape) == 2
    dtype_is_valid = np.issubdtype(dtype, np.integer)
    if raise_if_invalid and ((not shape_is_valid) or (not dtype_is_valid)):
        raise ValueError(
            f"Fingerprints file {path} is invalid. Shape: {shape}, DType {dtype}"
        )
    return shape, dtype, shape_is_valid, dtype_is_valid


def _print_fps_file_info(path: Path, console: Console | None = None) -> None:
    if console is None:
        console = Console()
    shape, dtype, shape_is_valid, dtype_is_valid = _get_fps_file_shape_and_dtype(path)

    console.print(f"File: {path.resolve()}")
    if shape_is_valid and dtype_is_valid:
        console.print(" - [green]Valid fingerprint file[/green]")
    else:
        console.print(" - [red]Invalid fingerprint file[/red]")
    if shape_is_valid:
        console.print(f" - Num. fingerprints: {shape[0]:,}")
        console.print(f" - Num. features: {shape[1]:,}")
    else:
        console.print(f" - Shape: {shape}")
    console.print(f" - DType: [yellow]{dtype.name}[/yellow]")
    console.print()

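These helpers read only the ``.npy`` header, so shape and dtype checks are O(1) regardless of file size. A minimal editorial sketch of the same trick with plain NumPy (file name hypothetical):

import numpy as np

np.save("fps.npy", np.zeros((10_000, 256), dtype=np.uint8))
with open("fps.npy", "rb") as f:
    major, minor = np.lib.format.read_magic(f)
    shape, fortran_order, dtype = np.lib.format.read_array_header_1_0(f)
print(shape, dtype)  # (10000, 256) uint8, without loading any array data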
class _FingerprintFileSequence:
    def __init__(self, files: tp.Iterable[Path]) -> None:
        self._files = list(files)
        if len(self._files) == 0:
            raise ValueError("At least 1 fingerprint file must be provided")

    def __getitem__(self, idxs: tp.Sequence[int]) -> NDArray[np.uint8]:
        return _get_fingerprints_from_file_seq(self._files, idxs)

    @property
    def shape(self) -> tuple[int, int]:
        shape, dtype, _, _ = _get_fps_file_shape_and_dtype(
            self._files[0], raise_if_invalid=True
        )
        return shape

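Editorial sketch of how the wrapper is indexed (file names hypothetical, and assuming the first file holds 10,000 rows; indices are global across the concatenated files and must be sorted):

from pathlib import Path

seq = _FingerprintFileSequence([Path("fps.0000.npy"), Path("fps.0001.npy")])
batch = seq[[0, 5, 12_000]]  # rows 0 and 5 of the first file, row 2000 of the second
print(batch.shape)           # (3, <columns of the underlying files>)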
# TODO: The logic of this function is pretty complicated, maybe there is a way to
# simplify it?
def _get_fingerprints_from_file_seq(
    files: tp.Iterable[Path], idxs: tp.Sequence[int]
) -> NDArray[np.uint8]:
    if sorted(idxs) != list(idxs):
        raise ValueError("idxs must be sorted")
    # The sequence of files is assumed to have indexes in increasing order,
    # for example, if the first two files have 10k fingerprints, then the
    # assoc. idxs are 0-9999 and 10000-19999. 'idxs' will index this sequence of files
    n_features = None
    local_file_idxs = []
    consumed_idxs = 0
    running_count = 0
    for f in files:
        (num, _n_features), _, _, _ = _get_fps_file_shape_and_dtype(
            f, raise_if_invalid=True
        )
        # Fetch the idxs that fall within this file (appends array([]) if the
        # file holds none of the requested idxs)
        file_idxs = list(
            filter(lambda x: x < running_count + num, idxs[consumed_idxs:])
        )
        consumed_idxs += len(file_idxs)
        local_file_idxs.append(np.array(file_idxs, dtype=np.uint64) - running_count)
        running_count += num

        if n_features is None:
            n_features = _n_features
        elif _n_features != n_features:
            raise ValueError(
                f"Incompatible n_features in fingerprint file {f},"
                f" expected {n_features}, found {_n_features}"
            )
    if len(idxs) != sum(arr.size for arr in local_file_idxs):
        raise ValueError("idxs could not be extracted from files")

    arr = np.empty((len(idxs), tp.cast(int, n_features)), dtype=np.uint8)
    i = 0
    for file, local_idxs in zip(files, local_file_idxs):
        size = local_idxs.size
        if not size:
            continue
        arr[i : i + size] = np.load(file, mmap_mode="r")[local_idxs].astype(
            np.uint8, copy=False
        )
        i += size
    return arr

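A worked editorial example of the index-partitioning step above in isolation (two hypothetical files of 10,000 rows each):

idxs = [5, 9_999, 10_002]
sizes = [10_000, 10_000]  # per-file row counts, as read from the .npy headers
consumed, running = 0, 0
for num in sizes:
    file_idxs = [x for x in idxs[consumed:] if x < running + num]
    print([x - running for x in file_idxs])  # [5, 9999], then [2]
    consumed += len(file_idxs)
    running += num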
# TODO: Skipping invalid smiles is a bit inefficient
# NOTE: Mostly convenient for usage in multiprocessing workflows
@dataclasses.dataclass
class _FingerprintFileCreator:
    dtype: str
    out_dir: Path
    out_name: str
    digits: int | None
    pack: bool
    kind: str
    n_features: int
    sanitize: str
    skip_invalid: bool
    verbose: bool

    def __call__(self, input_: tuple[int, tp.Sequence[str]]) -> None:
        console = get_console(self.verbose)
        fpg = _get_generator(self.kind, self.n_features)
        file_idx, batch = input_
        fps = np.empty((len(batch), self.n_features), dtype=self.dtype)
        out_name = self.out_name
        sanitize_flags = _get_sanitize_flags(self.sanitize)
        invalid = []
        for i, smi in enumerate(batch):
            mol = MolFromSmiles(smi, sanitize=False)
            if mol is None:
                if self.skip_invalid:
                    invalid.append(i)
                    continue
                else:
                    raise ValueError(f"Unable to parse smiles {smi}")
            try:
                SanitizeMol(mol, sanitizeOps=sanitize_flags)  # Raises if invalid
            except Exception:
                if self.skip_invalid:
                    invalid.append(i)
                    continue
                raise
            fps[i, :] = fpg.GetFingerprintAsNumPy(mol)
        if self.pack:
            fps = pack_fingerprints(fps)
        if self.digits is not None:
            out_name = f"{out_name}.{str(file_idx).zfill(self.digits)}"
        if invalid:
            prev_num = len(fps)
            fps = np.delete(fps, invalid, axis=0)
            new_num = len(fps)
            console.print(
                f"File {file_idx}: Generated {new_num} fingerprints\n"
                f" File {file_idx}: Skipped {prev_num - new_num} invalid smiles"
            )
        np.save(self.out_dir / out_name, fps)

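Editorial sketch of the multiprocessing workflow the NOTE above alludes to (paths, batching, and pool size are illustrative):

from multiprocessing import Pool
from pathlib import Path

creator = _FingerprintFileCreator(
    dtype="uint8", out_dir=Path("."), out_name="fps", digits=4,
    pack=True, kind="ecfp4", n_features=2048,
    sanitize="all", skip_invalid=True, verbose=False,
)
batches = [(0, ["CCO", "CCN"]), (1, ["c1ccccc1"])]  # (file_idx, smiles batch) pairs
with Pool(2) as pool:
    pool.map(creator, batches)  # writes fps.0000.npy and fps.0001.npy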
@dataclasses.dataclass
class _FingerprintArrayFiller:
    invalid_mask_shmem_name: str
    shmem_name: str
    kind: str
    fp_size: int
    pack: bool
    dtype: str
    num_smiles: int
    sanitize: str
    skip_invalid: bool

    def __call__(self, idx_range: tuple[int, int], batch: tp.Sequence[str]) -> None:
        fpg = _get_generator(self.kind, self.fp_size)
        (idx0, idx1) = idx_range
        fps_shmem = shmem.SharedMemory(name=self.shmem_name)
        invalid_mask_shmem = shmem.SharedMemory(name=self.invalid_mask_shmem_name)
        sanitize_flags = _get_sanitize_flags(self.sanitize)

        if self.pack:
            out_dim = (self.fp_size + 7) // 8
        else:
            out_dim = self.fp_size
        fps = np.ndarray(
            (self.num_smiles, out_dim), dtype=self.dtype, buffer=fps_shmem.buf
        )
        invalid_mask = np.ndarray(
            (self.num_smiles,), dtype=np.bool, buffer=invalid_mask_shmem.buf
        )
        for i, smi in zip(range(idx0, idx1), batch):
            mol = MolFromSmiles(smi, sanitize=False)
            if mol is None:
                if self.skip_invalid:
                    invalid_mask[i] = True
                    continue
                else:
                    raise ValueError(f"Unable to parse smiles {smi}")
            try:
                SanitizeMol(mol, sanitizeOps=sanitize_flags)  # Raises if invalid
            except Exception:
                if self.skip_invalid:
                    invalid_mask[i] = True
                    continue
                raise
            fp = fpg.GetFingerprintAsNumPy(mol)
            if self.pack:
                fp = pack_fingerprints(fp)
            fps[i, :] = fp
        fps_shmem.close()
        invalid_mask_shmem.close()
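Editorial sketch of the shared-memory setup a caller would need before dispatching the filler to workers (buffer names and sizes are illustrative; a bool mask needs one byte per entry):

import multiprocessing.shared_memory as shmem

num, fp_size = 100, 2048
fps_buf = shmem.SharedMemory(create=True, size=num * (fp_size // 8), name="fps_buf")
mask_buf = shmem.SharedMemory(create=True, size=num, name="mask_buf")
filler = _FingerprintArrayFiller(
    invalid_mask_shmem_name="mask_buf", shmem_name="fps_buf", kind="ecfp4",
    fp_size=fp_size, pack=True, dtype="uint8", num_smiles=num,
    sanitize="all", skip_invalid=True,
)
# ... dispatch filler((idx0, idx1), batch) calls across worker processes ...
fps_buf.close(); fps_buf.unlink()
mask_buf.close(); mask_buf.unlink()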
bblean/metrics.py
ADDED
@@ -0,0 +1,199 @@
r"""Clustering metrics using Tanimoto similarity"""

from contextlib import nullcontext

import numpy as np
from numpy.typing import NDArray
from rich.progress import Progress

from bblean.similarity import (
    jt_isim_from_sum,
    jt_sim_packed,
    jt_isim_packed,
    jt_isim_unpacked,
    centroid as centroid_from_fps,
    centroid_from_sum,
    jt_isim_medoid,
)
from bblean.fingerprints import unpack_fingerprints, pack_fingerprints

__all__ = ["jt_isim_chi", "jt_isim_dunn", "jt_dbi"]


def _calc_centrals(
    cluster_fps: list[NDArray[np.uint8]],
    kind: str,
    input_is_packed: bool = True,
    n_features: int | None = None,
    pack: bool = True,
) -> list[NDArray[np.uint8]]:
    if kind == "medoid":
        return [
            jt_isim_medoid(
                c, input_is_packed=input_is_packed, n_features=n_features, pack=pack
            )[1]
            for c in cluster_fps
        ]
    elif kind == "centroid":
        return [
            centroid_from_fps(
                c, input_is_packed=input_is_packed, n_features=n_features, pack=pack
            )
            for c in cluster_fps
        ]
    raise ValueError(f"Unknown kind {kind}, must be one of 'medoid|centroid'")

def jt_isim_chi(
    cluster_fps: list[NDArray[np.uint8]],
    all_fps_central: NDArray[np.uint8] | str = "centroid",
    centrals: list[NDArray[np.uint8]] | str = "centroid",
    input_is_packed: bool = True,
    n_features: int | None = None,
    verbose: bool = False,
) -> float:
    """Calinski-Harabasz clustering index

    An approximation to the CHI index using the Tanimoto iSIM. *Higher* is better.
    """
    all_fps_num = sum(len(c) for c in cluster_fps)
    if isinstance(all_fps_central, str):
        if not all_fps_central == "centroid":
            raise NotImplementedError("Currently only 'centroid' implemented for CHI")
        if input_is_packed:
            unpacked_clusts = [unpack_fingerprints(c, n_features) for c in cluster_fps]
        else:
            unpacked_clusts = cluster_fps
        total_linear_sum = sum(np.sum(c, axis=0) for c in unpacked_clusts)
        all_fps_central = centroid_from_sum(total_linear_sum, all_fps_num)

    if isinstance(centrals, str):
        if not centrals == "centroid":
            raise NotImplementedError("Currently only 'centroid' implemented for CHI")
        centrals = _calc_centrals(cluster_fps, centrals, input_is_packed, n_features)
    else:
        if not input_is_packed:
            centrals = [pack_fingerprints(c) for c in centrals]

    clusters_num = len(cluster_fps)
    # Packed cluster_fps required for CHI
    if not input_is_packed:
        cluster_fps = [pack_fingerprints(c) for c in cluster_fps]

    if clusters_num <= 1:
        return 0.0

    wcss = 0.0  # within-cluster sum of squares
    bcss = 0.0  # between-cluster sum of squares
    progress = Progress(transient=True) if verbose else nullcontext()
    with progress as pbar:
        if verbose:
            task = pbar.add_task(  # type: ignore
                "[italic]Calculating CHI[/italic]...",
                total=len(centrals),
            )
        for central, clust in zip(centrals, cluster_fps):
            # NOTE: In the original implementation there isn't a (1 - jt...) here (!)
            bcss += (
                len(clust) * (1 - jt_sim_packed(all_fps_central, central).item()) ** 2
            )
            d = 1 - jt_sim_packed(clust, central)
            wcss += np.dot(d, d)
            if verbose:
                pbar.update(task, advance=1)  # type: ignore
    # TODO: When can the denom be 0?
    return bcss * (all_fps_num - clusters_num) / (wcss * (clusters_num - 1))

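Editorial usage sketch for CHI, reusing the synthetic fingerprint helper from bblean/fingerprints.py above (cluster sizes and seeds are arbitrary):

from bblean.fingerprints import make_fake_fingerprints
from bblean.metrics import jt_isim_chi

clusters = [make_fake_fingerprints(n, seed=s) for s, n in enumerate((50, 80, 30))]
print(jt_isim_chi(clusters))  # higher values suggest tighter, better-separated clusters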
def jt_dbi(
    cluster_fps: list[NDArray[np.uint8]],
    centrals: list[NDArray[np.uint8]] | str = "centroid",
    input_is_packed: bool = True,
    n_features: int | None = None,
    verbose: bool = False,
) -> float:
    """Davies-Bouldin clustering index

    DBI index using the Tanimoto distance. *Lower* is better.
    """
    # Centrals can be medoids or centroids
    if isinstance(centrals, str):
        centrals = _calc_centrals(cluster_fps, centrals, input_is_packed, n_features)
    else:
        if not input_is_packed:
            centrals = [pack_fingerprints(c) for c in centrals]

    # Packed cluster_fps required for DBI
    if not input_is_packed:
        cluster_fps = [pack_fingerprints(c) for c in cluster_fps]

    fps_num = 0
    S: list[float] = []
    for central, clust_fps in zip(centrals, cluster_fps):
        size = len(clust_fps)
        S.append(np.sum(1 - jt_sim_packed(clust_fps, central)) / size)
        fps_num += size

    if fps_num == 0:
        return 0.0

    # Quadratic scaling on num. clusters
    progress = Progress(transient=True) if verbose else nullcontext()
    with progress as pbar:
        if verbose:
            task = pbar.add_task(  # type: ignore
                "[italic]Calculating DBI[/italic]...",
                total=(len(centrals) ** 2 - len(centrals)),
            )
        numerator = 0.0
        for i, central in enumerate(centrals):
            max_d = 0.0
            for j, other_central in enumerate(centrals):
                if i == j:
                    continue
                Mij = 1 - jt_sim_packed(central, other_central).item()
                max_d = max(max_d, (S[i] + S[j]) / Mij)
                if verbose:
                    pbar.update(task, advance=1)  # type: ignore
            numerator += max_d
    return numerator / fps_num

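Editorial sketch: DBI also accepts medoid centers via _calc_centrals above (reusing the ``clusters`` list from the CHI sketch):

from bblean.metrics import jt_dbi

print(jt_dbi(clusters))                     # default 'centroid' centers; lower is better
print(jt_dbi(clusters, centrals="medoid"))  # medoid centers, for comparison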
# This is the Dunn variant used in the original BitBirch article
def jt_isim_dunn(
    cluster_fps: list[NDArray[np.uint8]],
    input_is_packed: bool = True,
    n_features: int | None = None,
    verbose: bool = False,
) -> float:
    """Dunn clustering index

    An approximation to the Dunn index using the Tanimoto iSIM. *Higher* is better.
    """
    # Unpacked cluster_fps required for Dunn
    if input_is_packed:
        D = [jt_isim_packed(clust) for clust in cluster_fps]
        cluster_fps = [unpack_fingerprints(clust, n_features) for clust in cluster_fps]
    else:
        D = [jt_isim_unpacked(clust) for clust in cluster_fps]
    max_d = max(D)
    if max_d == 0:
        # TODO: Unclear what to return in this case, probably 1.0 is safer?
        return 1
    min_d = 1.0
    # Quadratic scaling on num. clusters
    pairs_num = len(cluster_fps) * (len(cluster_fps) - 1) // 2
    progress = Progress(transient=True) if verbose else nullcontext()
    with progress as pbar:
        if verbose:
            task = pbar.add_task(  # type: ignore
                "[italic]Calculating Dunn (slow)[/italic]...", total=pairs_num
            )
        for i, clust1 in enumerate(cluster_fps[:-1]):
            for j, clust2 in enumerate(cluster_fps[i + 1 :]):
                combined = np.sum(clust1, axis=0) + np.sum(clust2, axis=0)
                dij = 1 - jt_isim_from_sum(combined, len(clust1) + len(clust2))
                min_d = min(dij, min_d)
                if verbose:
                    pbar.update(task, advance=1)  # type: ignore
    return min_d / max_d
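A closing editorial sketch putting the three metrics side by side on the same synthetic clusters (all names come from the files in this diff; printed values are illustrative only):

from bblean.fingerprints import make_fake_fingerprints
from bblean.metrics import jt_isim_chi, jt_dbi, jt_isim_dunn

clusters = [make_fake_fingerprints(50, seed=s) for s in range(3)]  # packed uint8
print(jt_isim_chi(clusters))   # higher is better
print(jt_dbi(clusters))        # lower is better
print(jt_isim_dunn(clusters))  # higher is better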