bblean 0.6.0b1__cp311-cp311-macosx_10_9_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bblean/multiround.py ADDED
@@ -0,0 +1,489 @@
1
+ # BitBIRCH-Lean Python Package: An open-source clustering module based on iSIM.
2
+ #
3
+ # If you find this software useful please cite the following articles:
4
+ # - BitBIRCH: efficient clustering of large molecular libraries:
5
+ # https://doi.org/10.1039/D5DD00030K
6
+ # - BitBIRCH Clustering Refinement Strategies:
7
+ # https://doi.org/10.1021/acs.jcim.5c00627
8
+ # - BitBIRCH-Lean:
9
+ # (preprint) https://www.biorxiv.org/content/10.1101/2025.10.22.684015v1
10
+ #
11
+ # Copyright (C) 2025 The Miranda-Quintana Lab and other BitBirch developers, comprised
12
+ # exclusively by:
13
+ # - Ramon Alain Miranda Quintana <ramirandaq@gmail.com>, <quintana@chem.ufl.edu>
14
+ # - Krisztina Zsigmond <kzsigmond@ufl.edu>
15
+ # - Ignacio Pickering <ipickering@chem.ufl.edu>
16
+ # - Kenneth Lopez Perez <klopezperez@chem.ufl.edu>
17
+ # - Miroslav Lzicar <miroslav.lzicar@deepmedchem.com>
18
+ #
19
+ # Authors of this file are:
20
+ # - Ramon Alain Miranda Quintana <ramirandaq@gmail.com>, <quintana@chem.ufl.edu>
21
+ # - Ignacio Pickering <ipickering@chem.ufl.edu>
22
+ #
23
+ # This program is free software: you can redistribute it and/or modify it under the
24
+ # terms of the GNU General Public License as published by the Free Software Foundation,
25
+ # version 3 (SPDX-License-Identifier: GPL-3.0-only).
26
+ #
27
+ # Portions of ./bblean/bitbirch.py are licensed under the BSD 3-Clause License
28
+ # Copyright (c) 2007-2024 The scikit-learn developers. All rights reserved.
29
+ # (SPDX-License-Identifier: BSD-3-Clause). Copies or reproductions of code in the
30
+ # ./bblean/bitbirch.py file must in addition adhere to the BSD-3-Clause license terms. A
31
+ # copy of the BSD-3-Clause license can be located at the root of this repository, under
32
+ # ./LICENSES/BSD-3-Clause.txt.
33
+ #
34
+ # Portions of ./bblean/bitbirch.py were previously licensed under the LGPL 3.0
35
+ # license (SPDX-License-Identifier: LGPL-3.0-only), they are relicensed in this program
36
+ # as GPL-3.0, with permission of all original copyright holders:
37
+ # - Ramon Alain Miranda Quintana <ramirandaq@gmail.com>, <quintana@chem.ufl.edu>
38
+ # - Vicky (Vic) Jung <jungvicky@ufl.edu>
39
+ # - Kenneth Lopez Perez <klopezperez@chem.ufl.edu>
40
+ # - Kate Huddleston <kdavis2@chem.ufl.edu>
41
+ #
42
+ # This program is distributed in the hope that it will be useful, but WITHOUT ANY
43
+ # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
44
+ # PARTICULAR PURPOSE. See the GNU General Public License for more details.
45
+ #
46
+ # You should have received a copy of the GNU General Public License along with this
47
+ # program. This copy can be located at the root of this repository, under
48
+ # ./LICENSES/GPL-3.0-only.txt. If not, see <http://www.gnu.org/licenses/gpl-3.0.html>.
49
+ r"""Multi-round BitBirch workflow for clustering huge datasets in parallel"""
50
+ import sys
51
+ import math
52
+ import pickle
53
+ import typing as tp
54
+ import multiprocessing as mp
55
+ from pathlib import Path
56
+
57
+ from rich.console import Console
58
+ import numpy as np
59
+ from numpy.typing import NDArray
60
+
61
+
62
+ from bblean._console import get_console
63
+ from bblean._timer import Timer
64
+ from bblean._config import DEFAULTS
65
+ from bblean.utils import batched
66
+ from bblean.bitbirch import BitBirch
67
+ from bblean.fingerprints import _get_fps_file_num
68
+
69
+ __all__ = ["run_multiround_bitbirch"]
70
+
71
+
72
+ # Save a list of numpy arrays into a single array in a streaming fashion, avoiding
73
+ # stacking them in memory
74
+ def _numpy_streaming_save(
75
+ fp_list: list[NDArray[np.integer]] | NDArray[np.integer], path: Path | str
76
+ ) -> None:
77
+ first_arr = np.ascontiguousarray(fp_list[0])
78
+ header = np.lib.format.header_data_from_array_1_0(first_arr)
79
+ header["shape"] = (len(fp_list), len(first_arr))
80
+ path = Path(path)
81
+ if not path.suffix:
82
+ path = path.with_suffix(".npy")
83
+ with open(path, "wb") as f:
84
+ np.lib.format.write_array_header_1_0(f, header)
85
+ for arr in fp_list:
86
+ np.ascontiguousarray(arr).tofile(f)
87
+
88
+
89
+ # Glob and sort by uint bits and label, if a console is passed then the number of output
90
+ # files is printed
91
+ def _get_prev_round_buf_and_mol_idxs_files(
92
+ path: Path, round_idx: int, console: Console | None = None
93
+ ) -> list[tuple[Path, Path]]:
94
+ path = Path(path)
95
+ # TODO: Important: What should be the logic for batching? currently there doesn't
96
+ # seem to be much logic for grouping the files
97
+ buf_files = sorted(path.glob(f"round-{round_idx - 1}-bufs*.npy"))
98
+ idx_files = sorted(path.glob(f"round-{round_idx - 1}-idxs*.pkl"))
99
+ if console is not None:
100
+ console.print(f" - Collected {len(buf_files)} buffer-index file pairs")
101
+ return list(zip(buf_files, idx_files))
102
+
103
+
104
+ def _sort_batch(b: tp.Sequence[tuple[Path, Path]]) -> tuple[tuple[Path, Path], ...]:
105
+ return tuple(
106
+ sorted(
107
+ b,
108
+ key=lambda b: int(b[0].name.split("uint")[-1].split(".")[0]),
109
+ reverse=True,
110
+ )
111
+ )
112
+
113
+
114
def _chunk_file_pairs_in_batches(
    file_pairs: tp.Sequence[tuple[Path, Path]],
    bin_size: int,
    console: Console | None = None,
) -> list[tuple[str, tuple[tuple[Path, Path], ...]]]:
    r"""Group file pairs into zero-padded-labeled batches of ``bin_size`` pairs"""
    # Zero-pad the batch labels so they sort lexicographically
    label_width = len(str(math.ceil(len(file_pairs) / bin_size)))
    # Within each batch, sort the files by starting with the uint16 files, followed by
    # uint8 files, this helps that (approximately) the largest clusters are fitted first
    # which may improve final cluster quality
    batches = []
    for i, chunk in enumerate(batched(file_pairs, bin_size)):
        batches.append((str(i).zfill(label_width), _sort_batch(chunk)))
    if console is not None:
        console.print(f" - Chunked files into {len(batches)} batches")
    return batches
130
+
131
+
132
def _save_bufs_and_mol_idxs(
    out_dir: Path,
    fps_bfs: dict[str, tp.Any],
    mols_bfs: dict[str, tp.Any],
    label: str,
    round_idx: int,
) -> None:
    r"""Persist per-dtype buffers (``*.npy``) and molecule indices (``*.pkl``)"""
    for dtype, buf_list in fps_bfs.items():
        # Zero-pad 'uint8' -> 'uint08' so lexicographic sorting orders by bit width
        suffix = f".label-{label}-{dtype.replace('8', '08')}"
        buf_path = out_dir / f"round-{round_idx}-bufs{suffix}.npy"
        idx_path = out_dir / f"round-{round_idx}-idxs{suffix}.pkl"
        _numpy_streaming_save(buf_list, buf_path)
        with open(idx_path, mode="wb") as f:
            pickle.dump(mols_bfs[dtype], f)
144
+
145
+
146
class _InitialRound:
    r"""Picklable worker for round 1: fit one fingerprint file with BitBirch

    All configuration is stored as plain attributes in ``__init__`` and the
    actual work happens in ``__call__`` (one call per input file), so instances
    can be dispatched to worker processes via ``multiprocessing.Pool.map``.
    """

    def __init__(
        self,
        branching_factor: int,
        threshold: float,
        tolerance: float,
        out_dir: Path | str,
        refinement_before_midsection: str,
        refine_threshold_change: float,
        refine_merge_criterion: str,
        n_features: int | None = None,
        max_fps: int | None = None,
        merge_criterion: str = DEFAULTS.merge_criterion,
        input_is_packed: bool = True,
    ) -> None:
        self.n_features = n_features
        self.refinement_before_midsection = refinement_before_midsection
        # Fail fast on typos before any expensive work is dispatched to workers
        if refinement_before_midsection not in ["full", "split", "none"]:
            raise ValueError(f"Unknown refinement kind {refinement_before_midsection}")
        self.branching_factor = branching_factor
        self.threshold = threshold
        self.tolerance = tolerance
        self.out_dir = Path(out_dir)
        self.max_fps = max_fps
        self.merge_criterion = merge_criterion
        self.refine_merge_criterion = refine_merge_criterion
        self.input_is_packed = input_is_packed
        self.refine_threshold_change = refine_threshold_change

    def __call__(self, file_info: tuple[str, Path, int, int]) -> None:
        # file_info is one tuple produced by _get_files_range_tuples:
        # (zero-padded label, fps file path, global start idx, global end idx)
        file_label, fp_file, start_idx, end_idx = file_info

        # First fit the fps in each process, in parallel.
        # `reinsert_indices` required to keep track of mol idxs in different processes.
        tree = BitBirch(
            branching_factor=self.branching_factor,
            threshold=self.threshold,
            merge_criterion=self.merge_criterion,
        )
        # Global molecule indices covered by this file
        range_ = range(start_idx, end_idx)
        tree.fit(
            fp_file,
            reinsert_indices=range_,
            n_features=self.n_features,
            input_is_packed=self.input_is_packed,
            max_fps=self.max_fps,
        )
        # Extract the BitFeatures of the leaves, breaking the largest cluster(s) apart,
        # to prepare for refinement
        tree.delete_internal_nodes()
        if self.refinement_before_midsection == "none":
            fps_bfs, mols_bfs = tree._bf_to_np()
        elif self.refinement_before_midsection in ["split", "full"]:
            fps_bfs, mols_bfs = tree._bf_to_np_refine(fp_file, initial_mol=start_idx)
            if self.refinement_before_midsection == "full":
                # Finish the first refinement step internally in this round
                tree.reset()
                tree.set_merge(
                    self.refine_merge_criterion,
                    tolerance=self.tolerance,
                    threshold=self.threshold + self.refine_threshold_change,
                )
                for bufs, mol_idxs in zip(fps_bfs.values(), mols_bfs.values()):
                    tree._fit_buffers(bufs, reinsert_index_seqs=mol_idxs)
                    # Drop references eagerly to keep worker peak memory down
                    del mol_idxs
                    del bufs

                tree.delete_internal_nodes()
                fps_bfs, mols_bfs = tree._bf_to_np()

        # Dump buffers and indices tagged as round 1 for the midsection rounds
        _save_bufs_and_mol_idxs(self.out_dir, fps_bfs, mols_bfs, file_label, 1)
217
+
218
+
219
class _TreeMergingRound:
    r"""Picklable worker for midsection rounds: re-cluster a batch of buffers

    Each call rebuilds a BitBirch tree from a batch of (buffer, mol-idxs) file
    pairs produced by the previous round, then dumps the resulting buffers and
    indices for the next round.
    """

    def __init__(
        self,
        branching_factor: int,
        threshold: float,
        tolerance: float,
        round_idx: int,
        out_dir: Path | str,
        split_largest_cluster: bool,
        criterion: str,
        all_fp_paths: tp.Sequence[Path] = (),
    ) -> None:
        # Original fingerprint file paths; only consumed by _bf_to_np_refine
        # when split_largest_cluster is enabled
        self.all_fp_paths = list(all_fp_paths)
        self.branching_factor = branching_factor
        self.threshold = threshold
        self.tolerance = tolerance
        self.round_idx = round_idx
        self.out_dir = Path(out_dir)
        self.split_largest_cluster = split_largest_cluster
        self.criterion = criterion

    def __call__(self, batch_info: tuple[str, tp.Sequence[tuple[Path, Path]]]) -> None:
        # batch_info: (zero-padded batch label, pairs of (buffer .npy, idxs .pkl))
        batch_label, batch_path_pairs = batch_info
        tree = BitBirch(
            branching_factor=self.branching_factor,
            threshold=self.threshold,
            merge_criterion=self.criterion,
            tolerance=self.tolerance,
        )
        # Rebuild a tree, inserting all BitFeatures from the corresponding batch
        for buf_path, idx_path in batch_path_pairs:
            with open(idx_path, "rb") as f:
                mol_idxs = pickle.load(f)
            tree._fit_buffers(buf_path, reinsert_index_seqs=mol_idxs)
            # Drop the loaded indices eagerly to keep worker peak memory down
            del mol_idxs

        # Either do a refinement step, or fetch and save the bufs and idxs for the next
        # round
        tree.delete_internal_nodes()
        if self.split_largest_cluster:
            fps_bfs, mols_bfs = tree._bf_to_np_refine(self.all_fp_paths)
        else:
            fps_bfs, mols_bfs = tree._bf_to_np()
        _save_bufs_and_mol_idxs(
            self.out_dir, fps_bfs, mols_bfs, batch_label, self.round_idx
        )
265
+
266
+
267
class _FinalTreeMergingRound(_TreeMergingRound):
    r"""Callable for the last tree-merging round, which writes the final outputs

    Unlike the midsection rounds, this round merges all remaining buffer-index
    file pairs into a single tree and, instead of dumping buffers for a next
    round, saves the final clusters (and optionally the pickled tree and the
    cluster centroids) to ``out_dir``.
    """

    def __init__(
        self,
        branching_factor: int,
        threshold: float,
        tolerance: float,
        criterion: str,
        out_dir: Path | str,
        save_tree: bool,
        save_centroids: bool,
    ) -> None:
        # round_idx=-1 and split_largest_cluster=False: this round never feeds
        # buffers to a subsequent round, so those parent features are unused
        super().__init__(
            branching_factor, threshold, tolerance, -1, out_dir, False, criterion, ()
        )
        self.save_tree = save_tree
        self.save_centroids = save_centroids

    def __call__(self, batch_info: tuple[str, tp.Sequence[tuple[Path, Path]]]) -> None:
        # The batch label is ignored here; only the file pairs are used
        batch_path_pairs = batch_info[1]
        tree = BitBirch(
            branching_factor=self.branching_factor,
            threshold=self.threshold,
            merge_criterion=self.criterion,
            tolerance=self.tolerance,
        )
        # Rebuild a tree, inserting all BitFeatures from the corresponding batch
        for buf_path, idx_path in batch_path_pairs:
            with open(idx_path, "rb") as f:
                mol_idxs = pickle.load(f)
            tree._fit_buffers(buf_path, reinsert_index_seqs=mol_idxs)
            del mol_idxs

        # Save clusters and exit
        if self.save_tree:
            # TODO: BitBIRCH is highly recursive. pickling may crash python,
            # an alternative solution would be better
            _old_limit = sys.getrecursionlimit()
            sys.setrecursionlimit(100_000)
            try:
                with open(self.out_dir / "bitbirch.pkl", mode="wb") as f:
                    pickle.dump(tree, f)
            finally:
                # Restore the limit even if pickling fails, so the process is
                # not left running with a huge recursion limit
                sys.setrecursionlimit(_old_limit)
        tree.delete_internal_nodes()
        if self.save_centroids:
            output = tree.get_centroids_mol_ids()
            with open(self.out_dir / "clusters.pkl", mode="wb") as f:
                pickle.dump(output["mol_ids"], f)
            with open(self.out_dir / "cluster-centroids-packed.pkl", mode="wb") as f:
                pickle.dump(output["centroids"], f)
        else:
            with open(self.out_dir / "clusters.pkl", mode="wb") as f:
                pickle.dump(tree.get_cluster_mol_ids(), f)
318
+
319
+
320
# Create a list of tuples of labels, file paths and start-end idxs
def _get_files_range_tuples(
    files: tp.Sequence[Path],
) -> list[tuple[str, Path, int, int]]:
    r"""Assign each fps file a zero-padded label and its global mol-idx range"""
    width = len(str(len(files)))
    files_info: list[tuple[str, Path, int, int]] = []
    offset = 0
    for i, file in enumerate(files):
        # Global start/end indices accumulate across files so every molecule
        # gets a unique index over the whole dataset
        count = _get_fps_file_num(file)
        files_info.append((str(i).zfill(width), file, offset, offset + count))
        offset += count
    return files_info
333
+
334
+
335
# NOTE: 'full_refinement_before_midsection' indicates if the refinement of the batches
# is fully done after the tree-merging rounds, or if the data is only split before the
# tree-merging rounds
def run_multiround_bitbirch(
    input_files: tp.Sequence[Path],
    out_dir: Path,
    n_features: int | None = None,
    input_is_packed: bool = True,
    num_initial_processes: int = 10,
    num_midsection_processes: int | None = None,
    initial_merge_criterion: str = DEFAULTS.merge_criterion,
    branching_factor: int = DEFAULTS.branching_factor,
    threshold: float = DEFAULTS.threshold,
    midsection_threshold_change: float = DEFAULTS.refine_threshold_change,
    tolerance: float = DEFAULTS.tolerance,
    # Advanced
    num_midsection_rounds: int = 1,
    bin_size: int = 10,
    max_tasks_per_process: int = 1,
    refinement_before_midsection: str = "full",
    split_largest_after_each_midsection_round: bool = False,
    midsection_merge_criterion: str = DEFAULTS.refine_merge_criterion,
    final_merge_criterion: str | None = None,
    mp_context: tp.Any = None,
    save_tree: bool = False,
    save_centroids: bool = True,
    # Debug
    max_fps: int | None = None,
    verbose: bool = False,
    cleanup: bool = True,
) -> Timer:
    r"""Perform (possibly parallel) multi-round BitBirch clustering

    The workflow runs three stages, writing intermediate ``round-*`` files to
    ``out_dir``, and returns a :class:`Timer` with the wall time of each round:

    1. *Initial round*: each input fingerprint file is fitted with its own
       BitBirch tree (up to ``num_initial_processes`` processes in parallel).
    2. *Midsection rounds*: buffers dumped by the previous round are chunked
       into batches of ``bin_size`` file pairs and re-clustered in parallel,
       repeated ``num_midsection_rounds`` times.
    3. *Final round*: all remaining buffers are merged in a single process and
       the final clusters (and optionally centroids / the tree) are saved.

    .. warning::

        The functionality provided by this function is stable, but its API
        (the arguments it takes and its return values) may change in the future.
    """
    if final_merge_criterion is None:
        final_merge_criterion = midsection_merge_criterion
    if mp_context is None:
        # 'forkserver' on linux; elsewhere fall back to the platform default
        # start method (get_context(None))
        mp_context = mp.get_context("forkserver" if sys.platform == "linux" else None)
    # Returns timing and for the different rounds
    # TODO: Also return peak-rss
    console = get_console(silent=not verbose)

    if num_midsection_processes is None:
        num_midsection_processes = num_initial_processes
    else:
        # Sanity check
        if num_midsection_processes > num_initial_processes:
            raise ValueError("Num. midsection procs. must be <= num. initial processes")

    # Common params to all rounds BitBIRCH
    common_kwargs: dict[str, tp.Any] = dict(
        branching_factor=branching_factor,
        tolerance=tolerance,
        out_dir=out_dir,
    )
    timer = Timer()
    timer.init_timing("total")

    # Get starting and ending idxs for each file, and collect them into tuples
    files_range_tuples = _get_files_range_tuples(input_files)
    num_files = len(input_files)

    # Initial round of clustering
    round_idx = 1
    timer.init_timing(f"round-{round_idx}")
    console.print(f"(Initial) Round {round_idx}: Cluster initial batch of fingerprints")

    initial_fn = _InitialRound(
        n_features=n_features,
        refinement_before_midsection=refinement_before_midsection,
        max_fps=max_fps,
        merge_criterion=initial_merge_criterion,
        input_is_packed=input_is_packed,
        threshold=threshold,
        refine_merge_criterion=midsection_merge_criterion,
        refine_threshold_change=midsection_threshold_change,
        **common_kwargs,
    )
    num_ps = min(num_initial_processes, num_files)
    console.print(f" - Processing {num_files} inputs with {num_ps} processes")
    with console.status("[italic]BitBirching...[/italic]", spinner="dots"):
        # Avoid process-pool overhead when a single process suffices
        if num_ps == 1:
            for tup in files_range_tuples:
                initial_fn(tup)
        else:
            with mp_context.Pool(
                processes=num_ps, maxtasksperchild=max_tasks_per_process
            ) as pool:
                pool.map(initial_fn, files_range_tuples)

    timer.end_timing(f"round-{round_idx}", console)
    console.print_peak_mem(out_dir)

    # Mid-section "Tree-Merging" rounds of clustering
    for _ in range(num_midsection_rounds):
        round_idx += 1
        timer.init_timing(f"round-{round_idx}")
        console.print(f"(Midsection) Round {round_idx}: Re-clustering in chunks")

        file_pairs = _get_prev_round_buf_and_mol_idxs_files(out_dir, round_idx, console)
        batches = _chunk_file_pairs_in_batches(file_pairs, bin_size, console)
        merging_fn = _TreeMergingRound(
            round_idx=round_idx,
            all_fp_paths=input_files,
            split_largest_cluster=split_largest_after_each_midsection_round,
            criterion=midsection_merge_criterion,
            # Midsection rounds cluster with the threshold shifted by
            # midsection_threshold_change
            threshold=threshold + midsection_threshold_change,
            **common_kwargs,
        )
        num_ps = min(num_midsection_processes, len(batches))
        console.print(f" - Processing {len(batches)} inputs with {num_ps} processes")
        with console.status("[italic]BitBirching...[/italic]", spinner="dots"):
            if num_ps == 1:
                for batch_info in batches:
                    merging_fn(batch_info)
            else:
                with mp_context.Pool(
                    processes=num_ps, maxtasksperchild=max_tasks_per_process
                ) as pool:
                    pool.map(merging_fn, batches)

        timer.end_timing(f"round-{round_idx}", console)
        console.print_peak_mem(out_dir)

    # Final "Tree-Merging" round of clustering
    round_idx += 1
    timer.init_timing(f"round-{round_idx}")
    console.print(f"(Final) Round {round_idx}: Final round of clustering")
    file_pairs = _get_prev_round_buf_and_mol_idxs_files(out_dir, round_idx, console)

    final_fn = _FinalTreeMergingRound(
        save_tree=save_tree,
        save_centroids=save_centroids,
        criterion=final_merge_criterion,
        threshold=threshold + midsection_threshold_change,
        **common_kwargs,
    )
    with console.status("[italic]BitBirching...[/italic]", spinner="dots"):
        # All remaining file pairs are processed as a single batch; the batch
        # label is unused in the final round
        final_fn(("", file_pairs))

    timer.end_timing(f"round-{round_idx}", console)
    console.print_peak_mem(out_dir)
    # Remove intermediate files
    if cleanup:
        for f in out_dir.glob("round-*.npy"):
            f.unlink()
        for f in out_dir.glob("round-*.pkl"):
            f.unlink()
    console.print()
    timer.end_timing("total", console, indent=False)
    return timer