bblean 0.6.0b2__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bblean/fingerprints.py ADDED
@@ -0,0 +1,424 @@
+ r"""Utilities for manipulating fingerprints and fingerprint files"""
+
+ import warnings
+ import dataclasses
+ from pathlib import Path
+ from numpy.typing import NDArray, DTypeLike
+ import numpy as np
+ import typing as tp
+ import multiprocessing.shared_memory as shmem
+
+ from rich.console import Console
+ from rdkit.Chem import rdFingerprintGenerator, MolFromSmiles, SanitizeFlags, SanitizeMol
+
+ from bblean._config import DEFAULTS
+ from bblean._console import get_console
+
+ __all__ = [
+     "make_fake_fingerprints",
+     "fps_from_smiles",
+     "pack_fingerprints",
+     "unpack_fingerprints",
+ ]
+
+
+ # Deprecated
+ def calc_centroid(
+     linear_sum: NDArray[np.integer], n_samples: int, *, pack: bool = True
+ ) -> NDArray[np.uint8]:
+     warnings.warn(
+         "Please use `bblean.similarity.centroid_from_sum(...)` instead",
+         DeprecationWarning,
+         stacklevel=2,
+     )
+     if n_samples <= 1:
+         centroid = linear_sum.astype(np.uint8, copy=False)
+     else:
+         centroid = (linear_sum >= n_samples * 0.5).view(np.uint8)
+     if pack:
+         return np.packbits(centroid, axis=-1)
+     return centroid
+
+
+ centroid_from_sum = calc_centroid
+
+
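A minimal migration sketch for the deprecated helper above, assuming `bblean.similarity.centroid_from_sum` mirrors the signature of `calc_centroid` (as the deprecation warning suggests):

    from bblean.similarity import centroid_from_sum
    from bblean.fingerprints import make_fake_fingerprints, unpack_fingerprints

    fps = make_fake_fingerprints(100, n_features=2048, seed=0)  # packed fingerprints
    linear_sum = unpack_fingerprints(fps, 2048).sum(axis=0)     # per-bit popcounts
    centroid = centroid_from_sum(linear_sum, 100)               # packed majority-vote centroid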
+ def pack_fingerprints(a: NDArray[np.uint8]) -> NDArray[np.uint8]:
+     r"""Pack binary (only 0s and 1s) uint8 fingerprint arrays"""
+     # packbits may pad with zeros if n_features is not a multiple of 8
+     return np.packbits(a, axis=-1)
+
+
+ def unpack_fingerprints(
+     a: NDArray[np.uint8], n_features: int | None = None
+ ) -> NDArray[np.uint8]:
+     r"""Unpack packed uint8 arrays into binary uint8 arrays (with only 0s and 1s)
+
+     .. note::
+
+         If ``n_features`` is not passed, unpacking will only recover the correct
+         number of features if it is a multiple of 8; otherwise fingerprints will
+         be padded with zeros to the closest multiple of 8. This is generally not
+         an issue, since the most common fingerprint feature sizes (2048, 1024,
+         etc.) are multiples of 8, but if you are using a non-standard number of
+         features you should pass ``n_features`` explicitly.
+     """
+     # n_features is required to discard padded zeros if it is not a multiple of 8
+     return np.unpackbits(a, axis=-1, count=n_features)
+
+
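A round-trip sketch of the note above: with a feature count that is not a multiple of 8, `unpack_fingerprints` needs `n_features` to strip the zero padding added by `pack_fingerprints`:

    import numpy as np
    from bblean.fingerprints import pack_fingerprints, unpack_fingerprints

    fps = np.random.default_rng(0).integers(0, 2, size=(4, 12), dtype=np.uint8)
    packed = pack_fingerprints(fps)             # shape (4, 2): 12 bits padded to 16
    unpack_fingerprints(packed).shape           # (4, 16): padding is kept
    np.array_equal(unpack_fingerprints(packed, n_features=12), fps)  # True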
+ def make_fake_fingerprints(
+     num: int,
+     n_features: int = DEFAULTS.n_features,
+     pack: bool = True,
+     seed: int | None = None,
+     dtype: DTypeLike = np.uint8,
+ ) -> NDArray[np.uint8]:
+     r"""Make random fingerprints with statistics similar to (some) real databases"""
+     import scipy.stats  # Hide this import since scipy is heavy
+
+     if n_features < 1 or n_features % 8 != 0:
+         raise ValueError("n_features must be a multiple of 8, and greater than 0")
+     # Generate "synthetic" fingerprints with a popcount distribution
+     # similar to the one in a real smiles database.
+     # Fps are guaranteed to *not* be all zeros or all ones
+     if pack:
+         if np.dtype(dtype) != np.dtype(np.uint8):
+             raise ValueError("Only np.uint8 dtype is supported for packed output")
+     loc = 750
+     scale = 400
+     bounds = (0, n_features)
+     rng = np.random.default_rng(seed)
+     safe_bounds = (bounds[0] + 1, bounds[1] - 1)
+     a = (safe_bounds[0] - loc) / scale
+     b = (safe_bounds[1] - loc) / scale
+     popcounts_fake_float = scipy.stats.truncnorm.rvs(
+         a, b, loc=loc, scale=scale, size=num, random_state=rng
+     )
+     popcounts_fake = np.rint(popcounts_fake_float).astype(np.int64)
+     zerocounts_fake = n_features - popcounts_fake
+     repeats_fake = np.empty((num * 2), dtype=np.int64)
+     repeats_fake[0::2] = popcounts_fake
+     repeats_fake[1::2] = zerocounts_fake
+     initial = np.tile(np.array([1, 0], np.uint8), num)
+     expanded = np.repeat(initial, repeats=repeats_fake)
+     fps_fake = rng.permuted(expanded.reshape(num, n_features), axis=-1)
+     if pack:
+         return np.packbits(fps_fake, axis=1)
+     return fps_fake.astype(dtype, copy=False)
+
+
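A short usage sketch, using an explicit feature count rather than the package default:

    from bblean.fingerprints import make_fake_fingerprints

    packed = make_fake_fingerprints(1_000, n_features=2048, seed=42)
    packed.shape   # (1000, 256): packed uint8, 2048 bits / 8 per fingerprint
    raw = make_fake_fingerprints(1_000, n_features=2048, seed=42, pack=False)
    raw.shape      # (1000, 2048): binary 0/1 uint8 values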
+ def _get_generator(kind: str, n_features: int) -> tp.Any:
+     if kind == "rdkit":
+         return rdFingerprintGenerator.GetRDKitFPGenerator(fpSize=n_features)
+     elif kind == "ecfp4":
+         return rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=n_features)
+     elif kind == "ecfp6":
+         return rdFingerprintGenerator.GetMorganGenerator(radius=3, fpSize=n_features)
+     raise ValueError(f"Unknown kind {kind}. Should be one of 'rdkit|ecfp4|ecfp6'")
+
+
+ def _get_sanitize_flags(sanitize: str) -> tp.Any:
+     if sanitize == "all":
+         return SanitizeFlags.SANITIZE_ALL
+     elif sanitize == "minimal":
+         flags = SanitizeFlags.SANITIZE_CLEANUP | SanitizeFlags.SANITIZE_SYMMRINGS
+         return flags
+     else:
+         raise ValueError("Unknown 'sanitize', must be one of 'all', 'minimal'")
+
+
+ @tp.overload
+ def fps_from_smiles(
+     smiles: tp.Iterable[str],
+     kind: str = DEFAULTS.fp_kind,
+     n_features: int = DEFAULTS.n_features,
+     dtype: DTypeLike = np.uint8,
+     sanitize: str = "all",
+     skip_invalid: tp.Literal[False] = False,
+     pack: bool = True,
+ ) -> NDArray[np.uint8]:
+     pass
+
+
+ @tp.overload
+ def fps_from_smiles(
+     smiles: tp.Iterable[str],
+     kind: str = DEFAULTS.fp_kind,
+     n_features: int = DEFAULTS.n_features,
+     dtype: DTypeLike = np.uint8,
+     sanitize: str = "all",
+     skip_invalid: tp.Literal[True] = True,
+     pack: bool = True,
+ ) -> tuple[NDArray[np.uint8], NDArray[np.int64]]:
+     pass
+
+
+ def fps_from_smiles(
+     smiles: tp.Iterable[str],
+     kind: str = DEFAULTS.fp_kind,
+     n_features: int = DEFAULTS.n_features,
+     dtype: DTypeLike = np.uint8,
+     sanitize: str = "all",
+     skip_invalid: bool = False,
+     pack: bool = True,
+ ) -> tp.Union[NDArray[np.uint8], tuple[NDArray[np.uint8], NDArray[np.int64]]]:
+     r"""Convert a sequence of smiles into chemical fingerprints"""
+     if n_features < 1 or n_features % 8 != 0:
+         raise ValueError("n_features must be a multiple of 8, and greater than 0")
+     if isinstance(smiles, str):
+         smiles = [smiles]
+
+     if pack and np.dtype(dtype) != np.dtype(np.uint8):
+         raise ValueError("Packing only supported for uint8 dtype")
+
+     fpg = _get_generator(kind, n_features)
+
+     sanitize_flags = _get_sanitize_flags(sanitize)
+
+     smiles = list(smiles)
+     fps = np.empty((len(smiles), n_features), dtype=dtype)
+
+     invalid_idxs = []
+     for i, smi in enumerate(smiles):
+         mol = MolFromSmiles(smi, sanitize=False)
+         if mol is None:
+             if skip_invalid:
+                 invalid_idxs.append(i)
+                 continue
+             else:
+                 raise ValueError(f"Unable to parse smiles {smi}")
+         try:
+             SanitizeMol(mol, sanitizeOps=sanitize_flags)
+             fps[i, :] = fpg.GetFingerprintAsNumPy(mol)
+         except Exception:
+             if skip_invalid:
+                 invalid_idxs.append(i)
+                 continue
+             raise
+
+     if invalid_idxs:
+         fps = np.delete(fps, invalid_idxs, axis=0)
+     if pack:
+         if skip_invalid:
+             return pack_fingerprints(fps), np.array(invalid_idxs, dtype=np.int64)
+         return pack_fingerprints(fps)
+     if skip_invalid:
+         return fps, np.array(invalid_idxs, dtype=np.int64)
+     return fps
+
+
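Usage sketch for `fps_from_smiles`; note that with `skip_invalid=True` the return value is a tuple that also carries the indices of the dropped smiles:

    from bblean.fingerprints import fps_from_smiles

    fps = fps_from_smiles(["CCO", "c1ccccc1"], kind="ecfp4", n_features=2048)
    fps.shape                # (2, 256): packed ECFP4 fingerprints
    fps, invalid = fps_from_smiles(
        ["CCO", "not-a-smiles"], kind="ecfp4", n_features=2048, skip_invalid=True
    )
    invalid                  # array([1]): the unparseable entry was dropped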
+ def _get_fps_file_num(path: Path) -> int:
+     with open(path, mode="rb") as f:
+         major, minor = np.lib.format.read_magic(f)
+         shape, _, _ = getattr(np.lib.format, f"read_array_header_{major}_{minor}")(f)
+     return shape[0]
+
+
+ def _get_fps_file_shape_and_dtype(
+     path: Path, raise_if_invalid: bool = False
+ ) -> tuple[tuple[int, int], np.dtype, bool, bool]:
+     with open(path, mode="rb") as f:
+         major, minor = np.lib.format.read_magic(f)
+         shape, _, dtype = getattr(np.lib.format, f"read_array_header_{major}_{minor}")(
+             f
+         )
+     shape_is_valid = len(shape) == 2
+     dtype_is_valid = np.issubdtype(dtype, np.integer)
+     if raise_if_invalid and ((not shape_is_valid) or (not dtype_is_valid)):
+         raise ValueError(
+             f"Fingerprints file {path} is invalid. Shape: {shape}, DType {dtype}"
+         )
+     return shape, dtype, shape_is_valid, dtype_is_valid
+
+
+ def _print_fps_file_info(path: Path, console: Console | None = None) -> None:
+     if console is None:
+         console = Console()
+     shape, dtype, shape_is_valid, dtype_is_valid = _get_fps_file_shape_and_dtype(path)
+
+     console.print(f"File: {path.resolve()}")
+     if shape_is_valid and dtype_is_valid:
+         console.print(" - [green]Valid fingerprint file[/green]")
+     else:
+         console.print(" - [red]Invalid fingerprint file[/red]")
+     if shape_is_valid:
+         console.print(f" - Num. fingerprints: {shape[0]:,}")
+         console.print(f" - Num. features: {shape[1]:,}")
+     else:
+         console.print(f" - Shape: {shape}")
+     console.print(f" - DType: [yellow]{dtype.name}[/yellow]")
+     console.print()
+
+
+ class _FingerprintFileSequence:
+     def __init__(self, files: tp.Iterable[Path]) -> None:
+         self._files = list(files)
+         if len(self._files) == 0:
+             raise ValueError("At least 1 fingerprint file must be provided")
+
+     def __getitem__(self, idxs: tp.Sequence[int]) -> NDArray[np.uint8]:
+         return _get_fingerprints_from_file_seq(self._files, idxs)
+
+     @property
+     def shape(self) -> tuple[int, int]:
+         shape, dtype, _, _ = _get_fps_file_shape_and_dtype(
+             self._files[0], raise_if_invalid=True
+         )
+         return shape
+
+
+ # TODO: The logic of this function is pretty complicated, maybe there is a way to
+ # simplify it?
+ def _get_fingerprints_from_file_seq(
+     files: tp.Iterable[Path], idxs: tp.Sequence[int]
+ ) -> NDArray[np.uint8]:
+     if sorted(idxs) != list(idxs):
+         raise ValueError("idxs must be sorted")
+     # The sequence of files is assumed to hold indexes in increasing order: for
+     # example, if the first two files have 10k fingerprints each, the assoc. idxs
+     # are 0-9999 and 10000-19999. 'idxs' indexes into this sequence of files
+     n_features = None
+     local_file_idxs = []
+     consumed_idxs = 0
+     running_count = 0
+     for f in files:
+         (num, _n_features), _, _, _ = _get_fps_file_shape_and_dtype(
+             f, raise_if_invalid=True
+         )
+         # Fetch this file's idxs; append array([]) if there are no idxs in the file
+         file_idxs = list(
+             filter(lambda x: x < running_count + num, idxs[consumed_idxs:])
+         )
+         consumed_idxs += len(file_idxs)
+         local_file_idxs.append(np.array(file_idxs, dtype=np.uint64) - running_count)
+         running_count += num
+
+         if n_features is None:
+             n_features = _n_features
+         elif _n_features != n_features:
+             raise ValueError(
+                 f"Incompatible number of features in fingerprint file {f},"
+                 f" expected {n_features}, found {_n_features}"
+             )
+     if len(idxs) != sum(arr.size for arr in local_file_idxs):
+         raise ValueError("idxs could not be extracted from files")
+
+     arr = np.empty((len(idxs), tp.cast(int, n_features)), dtype=np.uint8)
+     i = 0
+     for file, local_idxs in zip(files, local_file_idxs):
+         size = local_idxs.size
+         if not size:
+             continue
+         arr[i : i + size] = np.load(file, mmap_mode="r")[local_idxs].astype(
+             np.uint8, copy=False
+         )
+         i += size
+     return arr
+
+
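An illustrative sketch of the indexing convention (file names are hypothetical): global indices run in increasing order across the file sequence, and `idxs` must be sorted:

    from pathlib import Path

    # Suppose fps.00.npy and fps.01.npy each hold 10,000 fingerprints; then
    # global indices 0-9999 map to fps.00.npy and 10000-19999 to fps.01.npy
    files = [Path("fps.00.npy"), Path("fps.01.npy")]
    rows = _get_fingerprints_from_file_seq(files, [5, 42, 10_007])
    # rows[0:2] come from fps.00.npy (local 5 and 42), rows[2] from fps.01.npy (local 7)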
+ # TODO: Skipping invalid smiles is a bit inefficient
+ # NOTE: Mostly convenient for usage in multiprocessing workflows
+ @dataclasses.dataclass
+ class _FingerprintFileCreator:
+     dtype: str
+     out_dir: Path
+     out_name: str
+     digits: int | None
+     pack: bool
+     kind: str
+     n_features: int
+     sanitize: str
+     skip_invalid: bool
+     verbose: bool
+
+     def __call__(self, input_: tuple[int, tp.Sequence[str]]) -> None:
+         console = get_console(self.verbose)
+         fpg = _get_generator(self.kind, self.n_features)
+         file_idx, batch = input_
+         fps = np.empty((len(batch), self.n_features), dtype=self.dtype)
+         out_name = self.out_name
+         sanitize_flags = _get_sanitize_flags(self.sanitize)
+         invalid = []
+         for i, smi in enumerate(batch):
+             mol = MolFromSmiles(smi, sanitize=False)
+             if mol is None:
+                 if self.skip_invalid:
+                     invalid.append(i)
+                     continue
+                 else:
+                     raise ValueError(f"Unable to parse smiles {smi}")
+             try:
+                 SanitizeMol(mol, sanitizeOps=sanitize_flags)  # Raises if invalid
+             except Exception:
+                 if self.skip_invalid:
+                     invalid.append(i)
+                     continue
+                 raise
+             fps[i, :] = fpg.GetFingerprintAsNumPy(mol)
+         if self.pack:
+             fps = pack_fingerprints(fps)
+         if self.digits is not None:
+             out_name = f"{out_name}.{str(file_idx).zfill(self.digits)}"
+         if invalid:
+             prev_num = len(fps)
+             fps = np.delete(fps, invalid, axis=0)
+             new_num = len(fps)
+             console.print(
+                 f"File {file_idx}: Generated {new_num} fingerprints\n"
+                 f" File {file_idx}: Skipped {prev_num - new_num} invalid smiles"
+             )
+         np.save(self.out_dir / out_name, fps)
+
+
+ @dataclasses.dataclass
+ class _FingerprintArrayFiller:
+     invalid_mask_shmem_name: str
+     shmem_name: str
+     kind: str
+     fp_size: int
+     pack: bool
+     dtype: str
+     num_smiles: int
+     sanitize: str
+     skip_invalid: bool
+
+     def __call__(self, idx_range: tuple[int, int], batch: tp.Sequence[str]) -> None:
+         fpg = _get_generator(self.kind, self.fp_size)
+         (idx0, idx1) = idx_range
+         fps_shmem = shmem.SharedMemory(name=self.shmem_name)
+         invalid_mask_shmem = shmem.SharedMemory(name=self.invalid_mask_shmem_name)
+         sanitize_flags = _get_sanitize_flags(self.sanitize)
+
+         if self.pack:
+             out_dim = (self.fp_size + 7) // 8
+         else:
+             out_dim = self.fp_size
+         fps = np.ndarray(
+             (self.num_smiles, out_dim), dtype=self.dtype, buffer=fps_shmem.buf
+         )
+         invalid_mask = np.ndarray(
+             (self.num_smiles,), dtype=np.bool_, buffer=invalid_mask_shmem.buf
+         )
+         for i, smi in zip(range(idx0, idx1), batch):
+             mol = MolFromSmiles(smi, sanitize=False)
+             if mol is None:
+                 if self.skip_invalid:
+                     invalid_mask[i] = True
+                     continue
+                 else:
+                     raise ValueError(f"Unable to parse smiles {smi}")
+             try:
+                 SanitizeMol(mol, sanitizeOps=sanitize_flags)  # Raises if invalid
+             except Exception:
+                 if self.skip_invalid:
+                     invalid_mask[i] = True
+                     continue
+                 raise
+             fp = fpg.GetFingerprintAsNumPy(mol)
+             if self.pack:
+                 fp = pack_fingerprints(fp)
+             fps[i, :] = fp
+         fps_shmem.close()
+         invalid_mask_shmem.close()
bblean/metrics.py ADDED
@@ -0,0 +1,199 @@
+ r"""Clustering metrics using Tanimoto similarity"""
+
+ from contextlib import nullcontext
+
+ import numpy as np
+ from numpy.typing import NDArray
+ from rich.progress import Progress
+
+ from bblean.similarity import (
+     jt_isim_from_sum,
+     jt_sim_packed,
+     jt_isim_packed,
+     jt_isim_unpacked,
+     centroid as centroid_from_fps,
+     centroid_from_sum,
+     jt_isim_medoid,
+ )
+ from bblean.fingerprints import unpack_fingerprints, pack_fingerprints
+
+ __all__ = ["jt_isim_chi", "jt_isim_dunn", "jt_dbi"]
+
+
+ def _calc_centrals(
+     cluster_fps: list[NDArray[np.uint8]],
+     kind: str,
+     input_is_packed: bool = True,
+     n_features: int | None = None,
+     pack: bool = True,
+ ) -> list[NDArray[np.uint8]]:
+     if kind == "medoid":
+         return [
+             jt_isim_medoid(
+                 c, input_is_packed=input_is_packed, n_features=n_features, pack=pack
+             )[1]
+             for c in cluster_fps
+         ]
+     elif kind == "centroid":
+         return [
+             centroid_from_fps(
+                 c, input_is_packed=input_is_packed, n_features=n_features, pack=pack
+             )
+             for c in cluster_fps
+         ]
+     raise ValueError(f"Unknown kind {kind}, must be one of 'medoid|centroid'")
+
+
+ def jt_isim_chi(
+     cluster_fps: list[NDArray[np.uint8]],
+     all_fps_central: NDArray[np.uint8] | str = "centroid",
+     centrals: list[NDArray[np.uint8]] | str = "centroid",
+     input_is_packed: bool = True,
+     n_features: int | None = None,
+     verbose: bool = False,
+ ) -> float:
+     """Calinski-Harabasz clustering index
+
+     An approximation to the CHI index using the Tanimoto iSIM. *Higher* is better.
+     """
+     all_fps_num = sum(len(c) for c in cluster_fps)
+     if isinstance(all_fps_central, str):
+         if all_fps_central != "centroid":
+             raise NotImplementedError("Currently only 'centroid' implemented for CHI")
+         if input_is_packed:
+             unpacked_clusts = [unpack_fingerprints(c, n_features) for c in cluster_fps]
+         else:
+             unpacked_clusts = cluster_fps
+         total_linear_sum = sum(np.sum(c, axis=0) for c in unpacked_clusts)
+         all_fps_central = centroid_from_sum(total_linear_sum, all_fps_num)
+
+     if isinstance(centrals, str):
+         if centrals != "centroid":
+             raise NotImplementedError("Currently only 'centroid' implemented for CHI")
+         centrals = _calc_centrals(cluster_fps, centrals, input_is_packed, n_features)
+     else:
+         if not input_is_packed:
+             centrals = [pack_fingerprints(c) for c in centrals]
+
+     clusters_num = len(cluster_fps)
+     # Packed cluster_fps required for CHI
+     if not input_is_packed:
+         cluster_fps = [pack_fingerprints(c) for c in cluster_fps]
+
+     if clusters_num <= 1:
+         return 0.0
+
+     wcss = 0.0  # within-cluster sum of squares
+     bcss = 0.0  # between-cluster sum of squares
+     progress = Progress(transient=True) if verbose else nullcontext()
+     with progress as pbar:
+         if verbose:
+             task = pbar.add_task(  # type: ignore
+                 "[italic]Calculating CHI[/italic]...",
+                 total=(len(centrals)),
+             )
+         for central, clust in zip(centrals, cluster_fps):
+             # NOTE: In the original implementation there isn't a (1 - jt...) here (!)
+             bcss += (
+                 len(clust) * (1 - jt_sim_packed(all_fps_central, central).item()) ** 2
+             )
+             d = 1 - jt_sim_packed(clust, central)
+             wcss += np.dot(d, d)
+             if verbose:
+                 pbar.update(task, advance=1)  # type: ignore
+     # TODO: When can the denom be 0?
+     return bcss * (all_fps_num - clusters_num) / (wcss * (clusters_num - 1))
+
+
+ def jt_dbi(
+     cluster_fps: list[NDArray[np.uint8]],
+     centrals: list[NDArray[np.uint8]] | str = "centroid",
+     input_is_packed: bool = True,
+     n_features: int | None = None,
+     verbose: bool = False,
+ ) -> float:
+     """Davies-Bouldin clustering index
+
+     DBI index using the Tanimoto distance. *Lower* is better.
+     """
+     # Centrals can be 'medoid' or 'centroid'
+     if isinstance(centrals, str):
+         centrals = _calc_centrals(cluster_fps, centrals, input_is_packed, n_features)
+     else:
+         if not input_is_packed:
+             centrals = [pack_fingerprints(c) for c in centrals]
+
+     # Packed cluster_fps required for DBI
+     if not input_is_packed:
+         cluster_fps = [pack_fingerprints(c) for c in cluster_fps]
+
+     fps_num = 0
+     S: list[float] = []
+     for central, clust_fps in zip(centrals, cluster_fps):
+         size = len(clust_fps)
+         S.append(np.sum(1 - jt_sim_packed(clust_fps, central)) / size)
+         fps_num += size
+
+     if fps_num == 0:
+         return 0.0
+
+     # Quadratic scaling on num. clusters
+     progress = Progress(transient=True) if verbose else nullcontext()
+     with progress as pbar:
+         if verbose:
+             task = pbar.add_task(  # type: ignore
+                 "[italic]Calculating DBI[/italic]...",
+                 total=(len(centrals) ** 2 - len(centrals)),
+             )
+         numerator = 0.0
+         for i, central in enumerate(centrals):
+             max_d = 0.0
+             for j, other_central in enumerate(centrals):
+                 if i == j:
+                     continue
+                 Mij = 1 - jt_sim_packed(central, other_central).item()
+                 max_d = max(max_d, (S[i] + S[j]) / Mij)
+                 if verbose:
+                     pbar.update(task, advance=1)  # type: ignore
+             numerator += max_d
+     return numerator / fps_num
+
+
+ # This is the Dunn variant used in the original BitBirch article
+ def jt_isim_dunn(
+     cluster_fps: list[NDArray[np.uint8]],
+     input_is_packed: bool = True,
+     n_features: int | None = None,
+     verbose: bool = False,
+ ) -> float:
+     """Dunn clustering index
+
+     An approximation to the Dunn index using the Tanimoto iSIM. *Higher* is better.
+     """
+     # Unpacked cluster_fps required for Dunn
+     if input_is_packed:
+         D = [jt_isim_packed(clust) for clust in cluster_fps]
+         cluster_fps = [unpack_fingerprints(clust, n_features) for clust in cluster_fps]
+     else:
+         D = [jt_isim_unpacked(clust) for clust in cluster_fps]
+     max_d = max(D)
+     if max_d == 0:
+         # TODO: Unclear what to return in this case, probably 1.0 is safer?
+         return 1.0
+     min_d = 1.00
+     # Quadratic scaling on num. clusters
+     pairs_num = len(cluster_fps) * (len(cluster_fps) - 1) // 2
+     progress = Progress(transient=True) if verbose else nullcontext()
+     with progress as pbar:
+         if verbose:
+             task = pbar.add_task(  # type: ignore
+                 "[italic]Calculating Dunn (slow)[/italic]...", total=pairs_num
+             )
+         for i, clust1 in enumerate(cluster_fps[:-1]):
+             for clust2 in cluster_fps[i + 1 :]:
+                 combined = np.sum(clust1, axis=0) + np.sum(clust2, axis=0)
+                 dij = 1 - jt_isim_from_sum(combined, len(clust1) + len(clust2))
+                 min_d = min(dij, min_d)
+                 if verbose:
+                     pbar.update(task, advance=1)  # type: ignore
+     return min_d / max_d
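A small end-to-end sketch tying the two modules together, using synthetic packed fingerprints as stand-in "clusters":

    from bblean.fingerprints import make_fake_fingerprints
    from bblean.metrics import jt_isim_chi, jt_dbi, jt_isim_dunn

    clusters = [
        make_fake_fingerprints(500, n_features=2048, seed=1),
        make_fake_fingerprints(300, n_features=2048, seed=2),
    ]
    chi = jt_isim_chi(clusters, n_features=2048)    # higher is better
    dbi = jt_dbi(clusters, n_features=2048)         # lower is better
    dunn = jt_isim_dunn(clusters, n_features=2048)  # higher is better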