bblean 0.6.0b1__cp312-cp312-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bblean/__init__.py +22 -0
- bblean/_config.py +61 -0
- bblean/_console.py +187 -0
- bblean/_cpp_similarity.cpython-312-darwin.so +0 -0
- bblean/_legacy/__init__.py +0 -0
- bblean/_legacy/bb_int64.py +1252 -0
- bblean/_legacy/bb_uint8.py +1144 -0
- bblean/_memory.py +198 -0
- bblean/_merges.py +212 -0
- bblean/_py_similarity.py +278 -0
- bblean/_timer.py +42 -0
- bblean/_version.py +34 -0
- bblean/analysis.py +258 -0
- bblean/bitbirch.py +1437 -0
- bblean/cli.py +1854 -0
- bblean/csrc/README.md +1 -0
- bblean/csrc/similarity.cpp +521 -0
- bblean/fingerprints.py +424 -0
- bblean/metrics.py +199 -0
- bblean/multiround.py +489 -0
- bblean/plotting.py +479 -0
- bblean/similarity.py +304 -0
- bblean/sklearn.py +203 -0
- bblean/smiles.py +61 -0
- bblean/utils.py +130 -0
- bblean-0.6.0b1.dist-info/METADATA +283 -0
- bblean-0.6.0b1.dist-info/RECORD +31 -0
- bblean-0.6.0b1.dist-info/WHEEL +6 -0
- bblean-0.6.0b1.dist-info/entry_points.txt +2 -0
- bblean-0.6.0b1.dist-info/licenses/LICENSE +48 -0
- bblean-0.6.0b1.dist-info/top_level.txt +1 -0
bblean/multiround.py
ADDED
|
@@ -0,0 +1,489 @@
|
|
|
1
|
+
# BitBIRCH-Lean Python Package: An open-source clustering module based on iSIM.
|
|
2
|
+
#
|
|
3
|
+
# If you find this software useful please cite the following articles:
|
|
4
|
+
# - BitBIRCH: efficient clustering of large molecular libraries:
|
|
5
|
+
# https://doi.org/10.1039/D5DD00030K
|
|
6
|
+
# - BitBIRCH Clustering Refinement Strategies:
|
|
7
|
+
# https://doi.org/10.1021/acs.jcim.5c00627
|
|
8
|
+
# - BitBIRCH-Lean:
|
|
9
|
+
# (preprint) https://www.biorxiv.org/content/10.1101/2025.10.22.684015v1
|
|
10
|
+
#
|
|
11
|
+
# Copyright (C) 2025 The Miranda-Quintana Lab and other BitBirch developers, comprised
|
|
12
|
+
# exclusively by:
|
|
13
|
+
# - Ramon Alain Miranda Quintana <ramirandaq@gmail.com>, <quintana@chem.ufl.edu>
|
|
14
|
+
# - Krisztina Zsigmond <kzsigmond@ufl.edu>
|
|
15
|
+
# - Ignacio Pickering <ipickering@chem.ufl.edu>
|
|
16
|
+
# - Kenneth Lopez Perez <klopezperez@chem.ufl.edu>
|
|
17
|
+
# - Miroslav Lzicar <miroslav.lzicar@deepmedchem.com>
|
|
18
|
+
#
|
|
19
|
+
# Authors of this file are:
|
|
20
|
+
# - Ramon Alain Miranda Quintana <ramirandaq@gmail.com>, <quintana@chem.ufl.edu>
|
|
21
|
+
# - Ignacio Pickering <ipickering@chem.ufl.edu>
|
|
22
|
+
#
|
|
23
|
+
# This program is free software: you can redistribute it and/or modify it under the
|
|
24
|
+
# terms of the GNU General Public License as published by the Free Software Foundation,
|
|
25
|
+
# version 3 (SPDX-License-Identifier: GPL-3.0-only).
|
|
26
|
+
#
|
|
27
|
+
# Portions of ./bblean/bitbirch.py are licensed under the BSD 3-Clause License
|
|
28
|
+
# Copyright (c) 2007-2024 The scikit-learn developers. All rights reserved.
|
|
29
|
+
# (SPDX-License-Identifier: BSD-3-Clause). Copies or reproductions of code in the
|
|
30
|
+
# ./bblean/bitbirch.py file must in addition adhere to the BSD-3-Clause license terms. A
|
|
31
|
+
# copy of the BSD-3-Clause license can be located at the root of this repository, under
|
|
32
|
+
# ./LICENSES/BSD-3-Clause.txt.
|
|
33
|
+
#
|
|
34
|
+
# Portions of ./bblean/bitbirch.py were previously licensed under the LGPL 3.0
|
|
35
|
+
# license (SPDX-License-Identifier: LGPL-3.0-only), they are relicensed in this program
|
|
36
|
+
# as GPL-3.0, with permission of all original copyright holders:
|
|
37
|
+
# - Ramon Alain Miranda Quintana <ramirandaq@gmail.com>, <quintana@chem.ufl.edu>
|
|
38
|
+
# - Vicky (Vic) Jung <jungvicky@ufl.edu>
|
|
39
|
+
# - Kenneth Lopez Perez <klopezperez@chem.ufl.edu>
|
|
40
|
+
# - Kate Huddleston <kdavis2@chem.ufl.edu>
|
|
41
|
+
#
|
|
42
|
+
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
|
|
43
|
+
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
|
|
44
|
+
# PARTICULAR PURPOSE. See the GNU General Public License for more details.
|
|
45
|
+
#
|
|
46
|
+
# You should have received a copy of the GNU General Public License along with this
|
|
47
|
+
# program. This copy can be located at the root of this repository, under
|
|
48
|
+
# ./LICENSES/GPL-3.0-only.txt. If not, see <http://www.gnu.org/licenses/gpl-3.0.html>.
|
|
49
|
+
r"""Multi-round BitBirch workflow for clustering huge datasets in parallel"""
|
|
50
|
+
import sys
|
|
51
|
+
import math
|
|
52
|
+
import pickle
|
|
53
|
+
import typing as tp
|
|
54
|
+
import multiprocessing as mp
|
|
55
|
+
from pathlib import Path
|
|
56
|
+
|
|
57
|
+
from rich.console import Console
|
|
58
|
+
import numpy as np
|
|
59
|
+
from numpy.typing import NDArray
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
from bblean._console import get_console
|
|
63
|
+
from bblean._timer import Timer
|
|
64
|
+
from bblean._config import DEFAULTS
|
|
65
|
+
from bblean.utils import batched
|
|
66
|
+
from bblean.bitbirch import BitBirch
|
|
67
|
+
from bblean.fingerprints import _get_fps_file_num
|
|
68
|
+
|
|
69
|
+
__all__ = ["run_multiround_bitbirch"]
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# Save a list of numpy arrays into a single array in a streaming fashion, avoiding
|
|
73
|
+
# stacking them in memory
|
|
74
|
+
def _numpy_streaming_save(
|
|
75
|
+
fp_list: list[NDArray[np.integer]] | NDArray[np.integer], path: Path | str
|
|
76
|
+
) -> None:
|
|
77
|
+
first_arr = np.ascontiguousarray(fp_list[0])
|
|
78
|
+
header = np.lib.format.header_data_from_array_1_0(first_arr)
|
|
79
|
+
header["shape"] = (len(fp_list), len(first_arr))
|
|
80
|
+
path = Path(path)
|
|
81
|
+
if not path.suffix:
|
|
82
|
+
path = path.with_suffix(".npy")
|
|
83
|
+
with open(path, "wb") as f:
|
|
84
|
+
np.lib.format.write_array_header_1_0(f, header)
|
|
85
|
+
for arr in fp_list:
|
|
86
|
+
np.ascontiguousarray(arr).tofile(f)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# Glob and sort by uint bits and label, if a console is passed then the number of output
# files is printed
def _get_prev_round_buf_and_mol_idxs_files(
    path: Path, round_idx: int, console: Console | None = None
) -> list[tuple[Path, Path]]:
    r"""Collect (buffer-file, mol-idxs-file) pairs written by the previous round.

    Buffer files are ``round-<round_idx - 1>-bufs*.npy`` and index files are
    ``round-<round_idx - 1>-idxs*.pkl``; both globs are sorted so pairing is by
    matching filename tags.

    Raises
    ------
    RuntimeError
        If the number of buffer and index files differ (e.g. stray files from
        a crashed previous run). ``zip`` would silently truncate and drop
        data, so fail loudly instead.
    """
    path = Path(path)
    # TODO: Important: What should be the logic for batching? currently there doesn't
    # seem to be much logic for grouping the files
    buf_files = sorted(path.glob(f"round-{round_idx - 1}-bufs*.npy"))
    idx_files = sorted(path.glob(f"round-{round_idx - 1}-idxs*.pkl"))
    if len(buf_files) != len(idx_files):
        raise RuntimeError(
            f"Mismatched number of buffer files ({len(buf_files)}) and index"
            f" files ({len(idx_files)}) for round {round_idx - 1}"
        )
    if console is not None:
        console.print(f" - Collected {len(buf_files)} buffer-index file pairs")
    return list(zip(buf_files, idx_files))
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _sort_batch(b: tp.Sequence[tuple[Path, Path]]) -> tuple[tuple[Path, Path], ...]:
|
|
105
|
+
return tuple(
|
|
106
|
+
sorted(
|
|
107
|
+
b,
|
|
108
|
+
key=lambda b: int(b[0].name.split("uint")[-1].split(".")[0]),
|
|
109
|
+
reverse=True,
|
|
110
|
+
)
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _chunk_file_pairs_in_batches(
    file_pairs: tp.Sequence[tuple[Path, Path]],
    bin_size: int,
    console: Console | None = None,
) -> list[tuple[str, tuple[tuple[Path, Path], ...]]]:
    r"""Group file pairs into zero-padded-labeled batches of ``bin_size`` pairs.

    Returns a list of ``(label, pairs)`` tuples, where ``label`` is the batch
    number zero-filled to a fixed width. If a console is passed, the number of
    batches is printed.
    """
    # Width of the zero-padded batch label (so labels sort lexicographically)
    pad = len(str(math.ceil(len(file_pairs) / bin_size)))
    # Within each batch, sort the files by starting with the uint16 files, followed by
    # uint8 files, this helps that (approximately) the largest clusters are fitted first
    # which may improve final cluster quality
    batches = []
    for batch_num, chunk in enumerate(batched(file_pairs, bin_size)):
        batches.append((str(batch_num).zfill(pad), _sort_batch(chunk)))
    if console is not None:
        console.print(f" - Chunked files into {len(batches)} batches")
    return batches
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _save_bufs_and_mol_idxs(
    out_dir: Path,
    fps_bfs: dict[str, tp.Any],
    mols_bfs: dict[str, tp.Any],
    label: str,
    round_idx: int,
) -> None:
    r"""Persist per-dtype buffers (``.npy``) and molecule indices (``.pkl``).

    For each dtype key in ``fps_bfs`` a pair of files tagged with the batch
    label and dtype is written under ``out_dir`` for the given round.
    """
    for dtype in fps_bfs:
        # Pad "uint8" -> "uint08" in the filename, presumably so lexicographic
        # sorting of the output files matches numeric bit-width order — verify
        tag = f".label-{label}-{dtype.replace('8', '08')}"
        _numpy_streaming_save(fps_bfs[dtype], out_dir / f"round-{round_idx}-bufs{tag}.npy")
        idxs_path = out_dir / f"round-{round_idx}-idxs{tag}.pkl"
        with open(idxs_path, mode="wb") as f:
            pickle.dump(mols_bfs[dtype], f)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class _InitialRound:
    r"""Callable for the first clustering round, run once per fingerprint file.

    Fits a fresh ``BitBirch`` tree on a single fingerprint file and writes the
    resulting leaf buffers and molecule-index lists to ``out_dir`` tagged as
    round 1, where the subsequent "tree-merging" rounds pick them up.
    Instances are pickled into worker processes (see ``pool.map`` in
    ``run_multiround_bitbirch``), which is why all configuration is stored as
    plain attributes.
    """

    def __init__(
        self,
        branching_factor: int,
        threshold: float,
        tolerance: float,
        out_dir: Path | str,
        refinement_before_midsection: str,
        refine_threshold_change: float,
        refine_merge_criterion: str,
        n_features: int | None = None,
        max_fps: int | None = None,
        merge_criterion: str = DEFAULTS.merge_criterion,
        input_is_packed: bool = True,
    ) -> None:
        self.n_features = n_features
        self.refinement_before_midsection = refinement_before_midsection
        # Validate eagerly, in the parent process, so a bad value fails before
        # any expensive fitting is dispatched to workers
        if refinement_before_midsection not in ["full", "split", "none"]:
            raise ValueError(f"Unknown refinement kind {refinement_before_midsection}")
        self.branching_factor = branching_factor
        self.threshold = threshold
        self.tolerance = tolerance
        self.out_dir = Path(out_dir)
        self.max_fps = max_fps
        self.merge_criterion = merge_criterion
        self.refine_merge_criterion = refine_merge_criterion
        self.input_is_packed = input_is_packed
        self.refine_threshold_change = refine_threshold_change

    def __call__(self, file_info: tuple[str, Path, int, int]) -> None:
        r"""Cluster one fingerprint file and save round-1 buffers and indices.

        ``file_info`` is a ``(label, fp_file, start_idx, end_idx)`` tuple as
        produced by ``_get_files_range_tuples``; ``start_idx``/``end_idx`` are
        the file's global molecule-index range across all input files.
        """
        file_label, fp_file, start_idx, end_idx = file_info

        # First fit the fps in each process, in parallel.
        # `reinsert_indices` required to keep track of mol idxs in different processes.
        tree = BitBirch(
            branching_factor=self.branching_factor,
            threshold=self.threshold,
            merge_criterion=self.merge_criterion,
        )
        range_ = range(start_idx, end_idx)
        tree.fit(
            fp_file,
            reinsert_indices=range_,
            n_features=self.n_features,
            input_is_packed=self.input_is_packed,
            max_fps=self.max_fps,
        )
        # Extract the BitFeatures of the leaves, breaking the largest cluster(s) apart,
        # to prepare for refinement
        tree.delete_internal_nodes()
        if self.refinement_before_midsection == "none":
            fps_bfs, mols_bfs = tree._bf_to_np()
        elif self.refinement_before_midsection in ["split", "full"]:
            # "split": only break up the largest cluster(s); the actual
            # re-clustering then happens in the midsection rounds
            fps_bfs, mols_bfs = tree._bf_to_np_refine(fp_file, initial_mol=start_idx)
            if self.refinement_before_midsection == "full":
                # Finish the first refinement step internally in this round
                tree.reset()
                tree.set_merge(
                    self.refine_merge_criterion,
                    tolerance=self.tolerance,
                    threshold=self.threshold + self.refine_threshold_change,
                )
                for bufs, mol_idxs in zip(fps_bfs.values(), mols_bfs.values()):
                    tree._fit_buffers(bufs, reinsert_index_seqs=mol_idxs)
                    # Release each buffer/index list as soon as it is refitted
                    # to keep peak memory down
                    del mol_idxs
                    del bufs
                tree.delete_internal_nodes()
                fps_bfs, mols_bfs = tree._bf_to_np()

        # Round index is hard-coded to 1: this callable *is* round 1
        _save_bufs_and_mol_idxs(self.out_dir, fps_bfs, mols_bfs, file_label, 1)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
class _TreeMergingRound:
    r"""Callable for a midsection "tree-merging" round, run once per batch.

    Rebuilds a ``BitBirch`` tree from the buffer/index file pairs of one batch
    (written by the previous round) and saves the resulting leaf buffers and
    molecule indices tagged with this round's index, for the next round to
    consume. Instances are pickled into worker processes via ``pool.map``.
    """

    def __init__(
        self,
        branching_factor: int,
        threshold: float,
        tolerance: float,
        round_idx: int,
        out_dir: Path | str,
        split_largest_cluster: bool,
        criterion: str,
        all_fp_paths: tp.Sequence[Path] = (),
    ) -> None:
        # Copy to a list so the instance pickles predictably into workers
        self.all_fp_paths = list(all_fp_paths)
        self.branching_factor = branching_factor
        self.threshold = threshold
        self.tolerance = tolerance
        self.round_idx = round_idx
        self.out_dir = Path(out_dir)
        self.split_largest_cluster = split_largest_cluster
        self.criterion = criterion

    def __call__(self, batch_info: tuple[str, tp.Sequence[tuple[Path, Path]]]) -> None:
        r"""Merge one batch of buffer/index file pairs into a new tree.

        ``batch_info`` is a ``(label, pairs)`` tuple as produced by
        ``_chunk_file_pairs_in_batches``.
        """
        batch_label, batch_path_pairs = batch_info
        tree = BitBirch(
            branching_factor=self.branching_factor,
            threshold=self.threshold,
            merge_criterion=self.criterion,
            tolerance=self.tolerance,
        )
        # Rebuild a tree, inserting all BitFeatures from the corresponding batch
        for buf_path, idx_path in batch_path_pairs:
            with open(idx_path, "rb") as f:
                mol_idxs = pickle.load(f)
            tree._fit_buffers(buf_path, reinsert_index_seqs=mol_idxs)
            # Free the index list before loading the next pair
            del mol_idxs

        # Either do a refinement step, or fetch and save the bufs and idxs for the next
        # round
        tree.delete_internal_nodes()
        if self.split_largest_cluster:
            fps_bfs, mols_bfs = tree._bf_to_np_refine(self.all_fp_paths)
        else:
            fps_bfs, mols_bfs = tree._bf_to_np()
        _save_bufs_and_mol_idxs(
            self.out_dir, fps_bfs, mols_bfs, batch_label, self.round_idx
        )
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
class _FinalTreeMergingRound(_TreeMergingRound):
    r"""Callable for the final round: merge all remaining buffers and save results.

    Unlike the midsection rounds this writes the final outputs
    (``clusters.pkl`` and optionally ``bitbirch.pkl`` /
    ``cluster-centroids-packed.pkl``) to ``out_dir`` instead of intermediate
    round files. It is called exactly once, on all remaining file pairs.
    """

    def __init__(
        self,
        branching_factor: int,
        threshold: float,
        tolerance: float,
        criterion: str,
        out_dir: Path | str,
        save_tree: bool,
        save_centroids: bool,
    ) -> None:
        # round_idx=-1 / split_largest_cluster=False / no fp paths: the final
        # round never writes intermediate round files nor splits clusters
        super().__init__(
            branching_factor, threshold, tolerance, -1, out_dir, False, criterion, ()
        )
        self.save_tree = save_tree
        self.save_centroids = save_centroids

    def __call__(self, batch_info: tuple[str, tp.Sequence[tuple[Path, Path]]]) -> None:
        r"""Merge all remaining buffer/index pairs and save the final clusters.

        The label part of ``batch_info`` is ignored; only the file pairs are
        used.
        """
        batch_path_pairs = batch_info[1]
        tree = BitBirch(
            branching_factor=self.branching_factor,
            threshold=self.threshold,
            merge_criterion=self.criterion,
            tolerance=self.tolerance,
        )
        # Rebuild a tree, inserting all BitFeatures from the corresponding batch
        for buf_path, idx_path in batch_path_pairs:
            with open(idx_path, "rb") as f:
                mol_idxs = pickle.load(f)
            tree._fit_buffers(buf_path, reinsert_index_seqs=mol_idxs)
            del mol_idxs

        # Save clusters and exit
        if self.save_tree:
            # TODO: BitBIRCH is highly recursive. pickling may crash python,
            # an alternative solution would be better
            _old_limit = sys.getrecursionlimit()
            sys.setrecursionlimit(100_000)
            try:
                with open(self.out_dir / "bitbirch.pkl", mode="wb") as f:
                    pickle.dump(tree, f)
            finally:
                # Always restore the limit, even if pickling fails — otherwise
                # the raised limit leaks into the rest of the process
                sys.setrecursionlimit(_old_limit)
        tree.delete_internal_nodes()
        if self.save_centroids:
            output = tree.get_centroids_mol_ids()
            with open(self.out_dir / "clusters.pkl", mode="wb") as f:
                pickle.dump(output["mol_ids"], f)
            with open(self.out_dir / "cluster-centroids-packed.pkl", mode="wb") as f:
                pickle.dump(output["centroids"], f)
        else:
            with open(self.out_dir / "clusters.pkl", mode="wb") as f:
                pickle.dump(tree.get_cluster_mol_ids(), f)
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
# Create a list of tuples of labels, file paths and start-end idxs
def _get_files_range_tuples(
    files: tp.Sequence[Path],
) -> list[tuple[str, Path, int, int]]:
    r"""Assign each file a zero-padded label and its global ``[start, end)`` range.

    Molecule indices are global across all files: each file's range starts
    where the previous file's range ended.
    """
    pad = len(str(len(files)))
    files_info: list[tuple[str, Path, int, int]] = []
    offset = 0
    for num, fpath in enumerate(files):
        count = _get_fps_file_num(fpath)
        files_info.append((str(num).zfill(pad), fpath, offset, offset + count))
        offset += count
    return files_info
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
# NOTE: 'full_refinement_before_midsection' indicates if the refinement of the batches
# is fully done after the tree-merging rounds, or if the data is only split before the
# tree-merging rounds
def run_multiround_bitbirch(
    input_files: tp.Sequence[Path],
    out_dir: Path,
    n_features: int | None = None,
    input_is_packed: bool = True,
    num_initial_processes: int = 10,
    num_midsection_processes: int | None = None,
    initial_merge_criterion: str = DEFAULTS.merge_criterion,
    branching_factor: int = DEFAULTS.branching_factor,
    threshold: float = DEFAULTS.threshold,
    midsection_threshold_change: float = DEFAULTS.refine_threshold_change,
    tolerance: float = DEFAULTS.tolerance,
    # Advanced
    num_midsection_rounds: int = 1,
    bin_size: int = 10,
    max_tasks_per_process: int = 1,
    refinement_before_midsection: str = "full",
    split_largest_after_each_midsection_round: bool = False,
    midsection_merge_criterion: str = DEFAULTS.refine_merge_criterion,
    final_merge_criterion: str | None = None,
    mp_context: tp.Any = None,
    save_tree: bool = False,
    save_centroids: bool = True,
    # Debug
    max_fps: int | None = None,
    verbose: bool = False,
    cleanup: bool = True,
) -> Timer:
    r"""Perform (possibly parallel) multi-round BitBirch clustering

    The workflow has three stages, each timed separately:

    1. *Initial round*: each input fingerprint file is clustered independently
       (``_InitialRound``), possibly across ``num_initial_processes`` worker
       processes; leaf buffers and molecule indices are written to ``out_dir``.
    2. *Midsection rounds* (``num_midsection_rounds`` of them): the previous
       round's output files are grouped into batches of ``bin_size`` pairs and
       each batch is re-clustered (``_TreeMergingRound``) in parallel.
    3. *Final round*: all remaining buffers are merged in a single process
       (``_FinalTreeMergingRound``), which writes ``clusters.pkl`` and,
       depending on flags, the pickled tree and packed centroids.

    Returns the ``Timer`` holding per-round and total timings. Intermediate
    ``round-*`` files are deleted at the end unless ``cleanup`` is False.

    .. warning::

        The functionality provided by this function is stable, but its API
        (the arguments it takes and its return values) may change in the future.
    """
    if final_merge_criterion is None:
        final_merge_criterion = midsection_merge_criterion
    if mp_context is None:
        # "forkserver" avoids fork-safety issues on linux; elsewhere use the
        # platform default start method
        mp_context = mp.get_context("forkserver" if sys.platform == "linux" else None)
    # Returns timings for the different rounds
    # TODO: Also return peak-rss
    console = get_console(silent=not verbose)

    if num_midsection_processes is None:
        num_midsection_processes = num_initial_processes
    else:
        # Sanity check
        if num_midsection_processes > num_initial_processes:
            raise ValueError("Num. midsection procs. must be <= num. initial processes")

    # Common params to all rounds BitBIRCH
    common_kwargs: dict[str, tp.Any] = dict(
        branching_factor=branching_factor,
        tolerance=tolerance,
        out_dir=out_dir,
    )
    timer = Timer()
    timer.init_timing("total")

    # Get starting and ending idxs for each file, and collect them into tuples
    files_range_tuples = _get_files_range_tuples(input_files)  # correct
    num_files = len(input_files)

    # Initial round of clustering
    round_idx = 1
    timer.init_timing(f"round-{round_idx}")
    console.print(f"(Initial) Round {round_idx}: Cluster initial batch of fingerprints")

    initial_fn = _InitialRound(
        n_features=n_features,
        refinement_before_midsection=refinement_before_midsection,
        max_fps=max_fps,
        merge_criterion=initial_merge_criterion,
        input_is_packed=input_is_packed,
        threshold=threshold,
        refine_merge_criterion=midsection_merge_criterion,
        refine_threshold_change=midsection_threshold_change,
        **common_kwargs,
    )
    # Never spawn more processes than there are files to process
    num_ps = min(num_initial_processes, num_files)
    console.print(f" - Processing {num_files} inputs with {num_ps} processes")
    with console.status("[italic]BitBirching...[/italic]", spinner="dots"):
        if num_ps == 1:
            # Run in-process to avoid pool overhead for the serial case
            for tup in files_range_tuples:
                initial_fn(tup)
        else:
            with mp_context.Pool(
                processes=num_ps, maxtasksperchild=max_tasks_per_process
            ) as pool:
                pool.map(initial_fn, files_range_tuples)

    timer.end_timing(f"round-{round_idx}", console)
    console.print_peak_mem(out_dir)

    # Mid-section "Tree-Merging" rounds of clustering
    for _ in range(num_midsection_rounds):
        round_idx += 1
        timer.init_timing(f"round-{round_idx}")
        console.print(f"(Midsection) Round {round_idx}: Re-clustering in chunks")

        # Pick up the previous round's output files and group them in batches
        file_pairs = _get_prev_round_buf_and_mol_idxs_files(out_dir, round_idx, console)
        batches = _chunk_file_pairs_in_batches(file_pairs, bin_size, console)
        merging_fn = _TreeMergingRound(
            round_idx=round_idx,
            all_fp_paths=input_files,
            split_largest_cluster=split_largest_after_each_midsection_round,
            criterion=midsection_merge_criterion,
            threshold=threshold + midsection_threshold_change,
            **common_kwargs,
        )
        num_ps = min(num_midsection_processes, len(batches))
        console.print(f" - Processing {len(batches)} inputs with {num_ps} processes")
        with console.status("[italic]BitBirching...[/italic]", spinner="dots"):
            if num_ps == 1:
                for batch_info in batches:
                    merging_fn(batch_info)
            else:
                with mp_context.Pool(
                    processes=num_ps, maxtasksperchild=max_tasks_per_process
                ) as pool:
                    pool.map(merging_fn, batches)

        timer.end_timing(f"round-{round_idx}", console)
        console.print_peak_mem(out_dir)

    # Final "Tree-Merging" round of clustering (single process, all remaining files)
    round_idx += 1
    timer.init_timing(f"round-{round_idx}")
    console.print(f"(Final) Round {round_idx}: Final round of clustering")
    file_pairs = _get_prev_round_buf_and_mol_idxs_files(out_dir, round_idx, console)

    final_fn = _FinalTreeMergingRound(
        save_tree=save_tree,
        save_centroids=save_centroids,
        criterion=final_merge_criterion,
        threshold=threshold + midsection_threshold_change,
        **common_kwargs,
    )
    with console.status("[italic]BitBirching...[/italic]", spinner="dots"):
        # The final round takes a single pseudo-batch with an empty label
        final_fn(("", file_pairs))

    timer.end_timing(f"round-{round_idx}", console)
    console.print_peak_mem(out_dir)
    # Remove intermediate files
    if cleanup:
        for f in out_dir.glob("round-*.npy"):
            f.unlink()
        for f in out_dir.glob("round-*.pkl"):
            f.unlink()
    console.print()
    timer.end_timing("total", console, indent=False)
    return timer
|