bblean 0.6.0b1__cp311-cp311-macosx_10_9_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bblean/bitbirch.py ADDED
@@ -0,0 +1,1437 @@
+ # BitBIRCH-Lean Python Package: An open-source clustering module based on iSIM.
+ #
+ # If you find this software useful please cite the following articles:
+ # - BitBIRCH: efficient clustering of large molecular libraries:
+ #   https://doi.org/10.1039/D5DD00030K
+ # - BitBIRCH Clustering Refinement Strategies:
+ #   https://doi.org/10.1021/acs.jcim.5c00627
+ # - BitBIRCH-Lean: TO-BE-ADDED
+ #
+ # Copyright (C) 2025 The Miranda-Quintana Lab and other BitBirch developers, comprised
+ # exclusively by:
+ # - Ramon Alain Miranda Quintana <ramirandaq@gmail.com>, <quintana@chem.ufl.edu>
+ # - Krisztina Zsigmond <kzsigmond@ufl.edu>
+ # - Ignacio Pickering <ipickering@chem.ufl.edu>
+ # - Kenneth Lopez Perez <klopezperez@chem.ufl.edu>
+ # - Miroslav Lzicar <miroslav.lzicar@deepmedchem.com>
+ #
+ # Authors of ./bblean/multiround.py are:
+ # - Ramon Alain Miranda Quintana <ramirandaq@gmail.com>, <quintana@chem.ufl.edu>
+ # - Ignacio Pickering <ipickering@chem.ufl.edu>
+ #
+ # This program is free software: you can redistribute it and/or modify it under the
+ # terms of the GNU General Public License as published by the Free Software Foundation,
+ # version 3 (SPDX-License-Identifier: GPL-3.0-only).
+ #
+ # Portions of this file are licensed under the BSD 3-Clause License
+ # Copyright (c) 2007-2024 The scikit-learn developers. All rights reserved.
+ # (SPDX-License-Identifier: BSD-3-Clause). Copies or reproductions of code in this
+ # file must in addition adhere to the BSD-3-Clause license terms. A
+ # copy of the BSD-3-Clause license can be located at the root of this repository, under
+ # ./LICENSES/BSD-3-Clause.txt.
+ #
+ # Portions of this file were previously licensed under the LGPL 3.0
+ # license (SPDX-License-Identifier: LGPL-3.0-only), they are relicensed in this program
+ # as GPL-3.0, with permission of all original copyright holders:
+ # - Ramon Alain Miranda Quintana <ramirandaq@gmail.com>, <quintana@chem.ufl.edu>
+ # - Vicky (Vic) Jung <jungvicky@ufl.edu>
+ # - Kenneth Lopez Perez <klopezperez@chem.ufl.edu>
+ # - Kate Huddleston <kdavis2@chem.ufl.edu>
+ #
+ # This program is distributed in the hope that it will be useful, but WITHOUT ANY
+ # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
+ # PARTICULAR PURPOSE. See the GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU General Public License along with this
+ # program. This copy can be located at the root of this repository, under
+ # ./LICENSES/GPL-3.0-only.txt. If not, see <http://www.gnu.org/licenses/gpl-3.0.html>.
+ r"""BitBirch 'Lean' class for fast, memory-efficient O(N) clustering"""
49
+ from __future__ import annotations # Stringize type annotations for no runtime overhead
50
+ import typing_extensions as tpx
51
+ import os
52
+ import random
53
+ from pathlib import Path
54
+ import warnings
55
+ import typing as tp
56
+ from typing import cast
57
+ from collections import defaultdict
58
+ from weakref import WeakSet
59
+
60
+ import numpy as np
61
+ from numpy.typing import NDArray, DTypeLike
62
+
63
+ from bblean._memory import _mmap_file_and_madvise_sequential, _ArrayMemPagesManager
64
+ from bblean._merges import get_merge_accept_fn, MergeAcceptFunction, BUILTIN_MERGES
65
+ from bblean.utils import min_safe_uint
66
+ from bblean.fingerprints import (
67
+ pack_fingerprints,
68
+ _get_fingerprints_from_file_seq,
69
+ )
70
+ from bblean.similarity import (
71
+ _jt_sim_arr_vec_packed,
72
+ jt_most_dissimilar_packed,
73
+ jt_isim_medoid,
74
+ centroid_from_sum,
75
+ )
76
+
77
+ if os.getenv("BITBIRCH_NO_EXTENSIONS"):
78
+ from bblean.fingerprints import unpack_fingerprints as _unpack_fingerprints
79
+ else:
80
+ try:
81
+ # NOTE: There are small gains from using this fn but only ~3%, so don't warn for
82
+ # now if this fails, and don't expose it
83
+ from bblean._cpp_similarity import unpack_fingerprints as _unpack_fingerprints # type: ignore # noqa
84
+ except ImportError:
85
+ from bblean.fingerprints import unpack_fingerprints as _unpack_fingerprints
86
+
87
+ __all__ = ["BitBirch"]
88
+
89
+
90
+ # For backwards compatibility with the global "set_merge", keep weak references to all
91
+ # the BitBirch instances and update them when set_merge is called
92
+ _BITBIRCH_INSTANCES: WeakSet["BitBirch"] = WeakSet()
93
+
94
+
95
+ # For backwards compatibility: global function used to accept merges
96
+ _global_merge_accept: MergeAcceptFunction | None = None
97
+
98
+ _Input = tp.Union[NDArray[np.integer], list[NDArray[np.integer]]]
99
+
100
+
101
+ # For backwards compatibility: set the global merge_accept function
+ def set_merge(merge_criterion: str, tolerance: float = 0.05) -> None:
+     r"""Sets the global criteria for merging subclusters in any BitBirch tree
+
+     For usage see `BitBirch.set_merge`
+
+     .. warning::
+
+         Use of this function is highly discouraged, instead use either:
+         ``bb_tree = BitBirch(...)``
+         ``bb_tree.set_merge(merge_criterion=..., tolerance=...)``
+         or directly: ``bb_tree = BitBirch(..., merge_criterion=..., tolerance=...)``
+
+     """
+     msg = (
+         "Use of the global `set_merge` function is highly discouraged,\n"
+         " instead use either:\n"
+         " bb_tree = BitBirch(...)\n"
+         " bb_tree.set_merge(merge_criterion=..., tolerance=...)\n"
+         " or directly: `bb_tree = BitBirch(..., merge_criterion=..., tolerance=...)`."
+     )
+     warnings.warn(msg, UserWarning)
+     # Set the global merge_accept function
+     global _global_merge_accept
+     _global_merge_accept = get_merge_accept_fn(merge_criterion, tolerance)
+     for bbirch in _BITBIRCH_INSTANCES:
+         bbirch._merge_accept_fn = _global_merge_accept
+
+
+ # Utility function to validate the n_features argument for packed inputs
+ def _validate_n_features(
+     X: _Input, input_is_packed: bool, n_features: int | None = None
+ ) -> int:
+     if len(X) == 0:
+         raise ValueError("Input must have at least 1 fingerprint")
+     if input_is_packed:
+         _padded_n_features = len(X[0]) * 8 if isinstance(X, list) else X.shape[1] * 8
+         if n_features is None:
+             # Assume a multiple of 8
+             return _padded_n_features
+         if _padded_n_features < n_features:
+             raise ValueError(
+                 "n_features is larger than the padded length, which is inconsistent"
+             )
+         return n_features
+
+     x_n_features = len(X[0]) if isinstance(X, list) else X.shape[1]
+     if n_features is not None:
+         if n_features != x_n_features:
+             raise ValueError(
+                 "n_features is redundant for non-packed inputs;"
+                 " if passed, it must be equal to X.shape[1] (or len(X[0]))."
+                 f" For passed X the inferred n_features was {x_n_features}."
+                 " If this value is not what you expected,"
+                 " make sure the passed X is actually unpacked."
+             )
+     return x_n_features
+
+
+ def _split_node(node: "_BFNode") -> tuple["_BFSubcluster", "_BFSubcluster"]:
+     """The node has to be split if there is no place for a new subcluster
+     in the node.
+     1. An extra empty node and two empty subclusters are initialized.
+     2. The pair of most distant subclusters is found.
+     3. The properties of the empty subclusters and nodes are updated, assigning
+        each remaining subcluster to whichever of the two distant subclusters it
+        is most similar to.
+     4. The two nodes are set as children to the two subclusters.
+     """
+     n_features = node.n_features
+     branching_factor = node.branching_factor
+     new_subcluster1 = _BFSubcluster(n_features=n_features)
+     new_subcluster2 = _BFSubcluster(n_features=n_features)
+
+     node1 = _BFNode(branching_factor, n_features)
+     node2 = node  # Rename for clarity
+     new_subcluster1.child = node1
+     new_subcluster2.child = node2
+
+     if node2.is_leaf:
+         # If is_leaf, _prev_leaf is guaranteed to be not None
+         # NOTE: cast seems to have a small overhead here for some reason
+         node1._prev_leaf = node2._prev_leaf
+         node2._prev_leaf._next_leaf = node1  # type: ignore
+         node1._next_leaf = node2
+         node2._prev_leaf = node1
+
+     # O(N) approximation to obtain "most dissimilar fingerprints" within an array
+     node1_idx, _, node1_sim, node2_sim = jt_most_dissimilar_packed(
+         node2.packed_centroids, n_features
+     )
+     node1_closer = node1_sim > node2_sim
+     # Make sure node1 and node2 are closest to themselves, even if all sims are equal.
+     # This can only happen when all node.packed_centroids are duplicates, leading to
+     # all distances between centroids being zero.
+
+     # TODO: Currently this behavior is buggy (?), seems like in some cases one of the
+     # subclusters may *never* get updated, double check this logic
+     node1_closer[node1_idx] = True
+     subclusters = node2._subclusters.copy()  # Shallow copy
+     node2._subclusters = []  # Reset the node
+     for idx, subcluster in enumerate(subclusters):
+         if node1_closer[idx]:
+             node1.append_subcluster(subcluster)
+             new_subcluster1.update(subcluster)
+         else:
+             node2.append_subcluster(subcluster)
+             new_subcluster2.update(subcluster)
+     return new_subcluster1, new_subcluster2
+
+
+ class _BFNode:
+     """Each node in a BitBirch tree is a _BFNode.
+
+     The _BFNode holds a maximum of branching_factor _BFSubclusters.
+
+     Parameters
+     ----------
+     branching_factor : int
+         Maximum number of _BFSubclusters in the node.
+
+     n_features : int
+         The number of features.
+
+     Attributes
+     ----------
+     _subclusters : list
+         List of _BFSubcluster for the _BFNode.
+
+     _prev_leaf : _BFNode
+         Only useful for leaf nodes, otherwise None
+
+     _next_leaf : _BFNode
+         Only useful for leaf nodes, otherwise None
+
+     _packed_centroids_buf : NDArray[np.uint8]
+         Packed array of shape (branching_factor + 1, (n_features + 7) // 8). The code
+         internally manipulates this buf rather than packed_centroids, which is just a
+         view of this buf.
+
+     packed_centroids : NDArray[np.uint8]
+         Packed array of shape (len(_subclusters), (n_features + 7) // 8).
+         View of the valid section of ``_packed_centroids_buf``.
+     """
+
+     # NOTE: Slots deactivates __dict__, and thus reduces memory usage of python objects
+     __slots__ = (
+         "n_features",
+         "_subclusters",
+         "_packed_centroids_buf",
+         "_prev_leaf",
+         "_next_leaf",
+     )
+
+     def __init__(self, branching_factor: int, n_features: int):
+         self.n_features = n_features
+         # The list of subclusters, centroids and squared norms
+         # to manipulate throughout.
+         self._subclusters: list["_BFSubcluster"] = []
+         # Centroids are stored packed. All centroids up to branching_factor are
+         # allocated in a contiguous array
+         self._packed_centroids_buf = np.empty(
+             (branching_factor + 1, (n_features + 7) // 8), dtype=np.uint8
+         )
+         # Nodes that are leaves have a non-null _prev_leaf
+         self._prev_leaf: tp.Optional["_BFNode"] = None
+         self._next_leaf: tp.Optional["_BFNode"] = None
+
+     @property
+     def is_leaf(self) -> bool:
+         return self._prev_leaf is not None
+
+     @property
+     def branching_factor(self) -> int:
+         return self._packed_centroids_buf.shape[0] - 1
+
+     @property
+     def packed_centroids(self) -> NDArray[np.uint8]:
+         # packed_centroids returns a view of the valid part of _packed_centroids_buf.
+         return self._packed_centroids_buf[: len(self._subclusters), :]
+
+     def append_subcluster(self, subcluster: "_BFSubcluster") -> None:
+         n_samples = len(self._subclusters)
+         self._subclusters.append(subcluster)
+         self._packed_centroids_buf[n_samples] = subcluster.packed_centroid
+
+     def update_split_subclusters(
+         self,
+         subcluster: "_BFSubcluster",
+         new_subcluster1: "_BFSubcluster",
+         new_subcluster2: "_BFSubcluster",
+     ) -> None:
+         """Remove a subcluster from a node and update it with the
+         split subclusters.
+         """
+         # Replace subcluster with new_subcluster1
+         idx = self._subclusters.index(subcluster)
+         self._subclusters[idx] = new_subcluster1
+         self._packed_centroids_buf[idx] = new_subcluster1.packed_centroid
+         # Append new_subcluster2
+         self.append_subcluster(new_subcluster2)
+
+     def insert_bf_subcluster(
+         self,
+         subcluster: "_BFSubcluster",
+         merge_accept_fn: MergeAcceptFunction,
+         threshold: float,
+     ) -> bool:
+         """Insert a new subcluster into the node."""
+         if not self._subclusters:
+             self.append_subcluster(subcluster)
+             return False
+
+         # Within this node, find the closest subcluster to the one to-be-inserted
+         sim_matrix = _jt_sim_arr_vec_packed(
+             self.packed_centroids, subcluster.packed_centroid
+         )
+         closest_idx = np.argmax(sim_matrix)
+         closest_subcluster = self._subclusters[closest_idx]
+         closest_node = closest_subcluster.child
+         if closest_node is None:
+             # The closest subcluster doesn't have a child node (this is a leaf
+             # node), attempt a direct merge
+             merge_was_successful = closest_subcluster.merge_subcluster(
+                 subcluster, threshold, merge_accept_fn
+             )
+             if not merge_was_successful:
+                 # Could not merge due to the criteria.
+                 # Append the subcluster, and check if splitting *this node* is needed
+                 self.append_subcluster(subcluster)
+                 return len(self._subclusters) > self.branching_factor
+             # Merge success, update the centroid
+             self._packed_centroids_buf[closest_idx] = closest_subcluster.packed_centroid
+             return False
+
+         # Hard case: the closest subcluster has a child (is 'tracking'), use recursion
+         child_must_be_split = closest_node.insert_bf_subcluster(
+             subcluster, merge_accept_fn, threshold
+         )
+         if child_must_be_split:
+             # Split the child node and redistribute subclusters. Update
+             # this node with the 'tracking' subclusters of the two new children.
+             # Then, check if *this node* needs splitting too
+             new_subcluster1, new_subcluster2 = _split_node(closest_node)
+             self.update_split_subclusters(
+                 closest_subcluster, new_subcluster1, new_subcluster2
+             )
+             return len(self._subclusters) > self.branching_factor
+
+         # Child need not be split, update the *tracking* closest subcluster
+         closest_subcluster.update(subcluster)
+         self._packed_centroids_buf[closest_idx] = self._subclusters[
+             closest_idx
+         ].packed_centroid
+         return False
+
+
+ class _BFSubcluster:
+     r"""Each subcluster in a _BFNode is called a _BFSubcluster.
+
+     A _BFSubcluster can have a _BFNode as its child.
+
+     Parameters
+     ----------
+     linear_sum : ndarray of shape (n_features,), default=None
+         Sample. This is kept optional to allow initialization of empty
+         subclusters.
+
+     Attributes
+     ----------
+     n_samples : int
+         Number of samples that belong to the subcluster.
+
+     linear_sum : ndarray
+         Linear sum of all the samples in a subcluster. Prevents holding
+         all sample data in memory.
+
+     packed_centroid : ndarray of shape ((n_features + 7) // 8,)
+         Packed centroid of the subcluster. Prevents recomputing centroids when
+         ``_BFNode.packed_centroids`` is called.
+
+     mol_indices : list, default=[]
+         List of indices of molecules included in the given cluster.
+
+     child : _BFNode
+         Child node of the subcluster. Once a given _BFNode is set as the child
+         of this _BFSubcluster, it is stored in ``self.child``.
+     """
+
+     # NOTE: Slots deactivates __dict__, and thus reduces memory usage of python objects
+     __slots__ = ("_buffer", "packed_centroid", "child", "mol_indices")
+
+     def __init__(
+         self,
+         *,
+         linear_sum: NDArray[np.integer] | None = None,
+         mol_indices: tp.Sequence[int] = (),
+         n_features: int = 2048,
+         buffer: NDArray[np.integer] | None = None,
+         check_indices: bool = True,
+     ):
+         # NOTE: Internally, _buffer holds both "linear_sum" and "n_samples". It is
+         # guaranteed to always have the minimum required uint dtype. It should not be
+         # accessed by external classes, only used internally. The individual parts can
+         # be accessed in a read-only way using the linear_sum and n_samples
+         # properties.
+         #
+         # IMPORTANT: To mutate instances of this class, *always* use the public API
+         # given by replace|add_to_n_samples_and_linear_sum(...)
+         if buffer is not None:
+             if linear_sum is not None:
+                 raise ValueError("'linear_sum' and 'buffer' are mutually exclusive")
+             if check_indices and len(mol_indices) != buffer[-1]:
+                 raise ValueError(
+                     "Expected len(mol_indices) == buffer[-1],"
+                     f" but found {len(mol_indices)} != {buffer[-1]}"
+                 )
+             self._buffer = buffer
+             self.packed_centroid = centroid_from_sum(buffer[:-1], buffer[-1], pack=True)
+         else:
+             if linear_sum is not None:
+                 if check_indices and len(mol_indices) != 1:
+                     raise ValueError(
+                         "Expected len(mol_indices) == 1,"
+                         f" but found {len(mol_indices)} != 1"
+                     )
+                 buffer = np.empty((len(linear_sum) + 1,), dtype=np.uint8)
+                 buffer[:-1] = linear_sum
+                 buffer[-1] = 1
+                 self._buffer = buffer
+                 self.packed_centroid = pack_fingerprints(
+                     linear_sum.astype(np.uint8, copy=False)
+                 )
+             else:
+                 # Empty subcluster
+                 if check_indices and len(mol_indices) != 0:
+                     raise ValueError(
+                         "Expected len(mol_indices) == 0 for empty subcluster,"
+                         f" but found {len(mol_indices)} != 0"
+                     )
+                 self._buffer = np.zeros((n_features + 1,), dtype=np.uint8)
+                 self.packed_centroid = np.empty(
+                     0, dtype=np.uint8
+                 )  # Will be overwritten
+         self.mol_indices = list(mol_indices)
+         self.child: tp.Optional["_BFNode"] = None
+
+     @property
+     def unpacked_centroid(self) -> NDArray[np.uint8]:
+         return _unpack_fingerprints(self.packed_centroid, self.n_features)
+
+     @property
+     def n_features(self) -> int:
+         return len(self._buffer) - 1
+
+     @property
+     def dtype_name(self) -> str:
+         return self._buffer.dtype.name
+
+     @property
+     def linear_sum(self) -> NDArray[np.integer]:
+         read_only_view = self._buffer[:-1]
+         read_only_view.flags.writeable = False
+         return read_only_view
+
+     @property
+     def n_samples(self) -> int:
+         # Returns a python int, which is guaranteed to never overflow in sums, so
+         # n_samples can always be safely added when accessed through this property
+         return self._buffer.item(-1)
+
+     # NOTE: Part of the contract is that all elements of linear_sum must always be
+     # less than or equal to n_samples. This function does not check this
+     def replace_n_samples_and_linear_sum(
+         self, n_samples: int, linear_sum: NDArray[np.integer]
+     ) -> None:
+         # Cast to the minimum uint that can hold the inputs
+         self._buffer = self._buffer.astype(min_safe_uint(n_samples), copy=False)
+         # NOTE: Assignments are safe and do not recast the buffer
+         self._buffer[:-1] = linear_sum
+         self._buffer[-1] = n_samples
+         self.packed_centroid = centroid_from_sum(linear_sum, n_samples, pack=True)
+
+     # NOTE: Part of the contract is that all elements of linear_sum must always be
+     # less than or equal to n_samples. This function does not check this
+     def add_to_n_samples_and_linear_sum(
+         self, n_samples: int, linear_sum: NDArray[np.integer]
+     ) -> None:
+         # Cast to the minimum uint that can hold the inputs
+         new_n_samples = self.n_samples + n_samples
+         self._buffer = self._buffer.astype(min_safe_uint(new_n_samples), copy=False)
+         # NOTE: Assignment and in-place add are safe and do not recast the buffer
+         self._buffer[:-1] += linear_sum
+         self._buffer[-1] = new_n_samples
+         self.packed_centroid = centroid_from_sum(
+             self._buffer[:-1], new_n_samples, pack=True
+         )
+
+     def update(self, subcluster: "_BFSubcluster") -> None:
+         self.add_to_n_samples_and_linear_sum(
+             subcluster.n_samples, subcluster.linear_sum
+         )
+         self.mol_indices.extend(subcluster.mol_indices)
+
+     def merge_subcluster(
+         self,
+         nominee_cluster: "_BFSubcluster",
+         threshold: float,
+         merge_accept_fn: MergeAcceptFunction,
+     ) -> bool:
+         """Check if a cluster is worthy enough to be merged. If yes, merge."""
+         old_n = self.n_samples
+         nom_n = nominee_cluster.n_samples
+         new_n = old_n + nom_n
+         old_ls = self.linear_sum
+         nom_ls = nominee_cluster.linear_sum
+         # np.add with an explicit dtype is safe from overflows, e.g.:
+         # np.add(np.uint8(255), np.uint8(255), dtype=np.uint16) == np.uint16(510)
+         new_ls = np.add(old_ls, nom_ls, dtype=min_safe_uint(new_n))
+         if merge_accept_fn(threshold, new_ls, new_n, old_ls, nom_ls, old_n, nom_n):
+             self.replace_n_samples_and_linear_sum(new_n, new_ls)
+             self.mol_indices.extend(nominee_cluster.mol_indices)
+             return True
+         return False
+
+
+ class _CentroidsMolIds(tp.TypedDict):
+     centroids: list[NDArray[np.uint8]]
+     mol_ids: list[list[int]]
+
+
+ class _MedoidsMolIds(tp.TypedDict):
+     medoids: NDArray[np.uint8]
+     mol_ids: list[list[int]]
+
+
+ class BitBirch:
+     r"""Implements the BitBIRCH clustering algorithm, 'Lean' version
+
+     Memory- and time-efficient, online-learning algorithm. It constructs a tree data
+     structure with the cluster centroids being read off the leaves.
+
+     If you find this software useful please cite the following articles:
+
+     - BitBIRCH: efficient clustering of large molecular libraries:
+       https://doi.org/10.1039/D5DD00030K
+     - BitBIRCH Clustering Refinement Strategies:
+       https://doi.org/10.1021/acs.jcim.5c00627
+     - BitBIRCH-Lean: TO-BE-ADDED
+
+     Parameters
+     ----------
+
+     threshold : float = 0.65
+         The similarity radius of the subcluster obtained by merging a new sample and
+         the closest subcluster should be greater than the threshold; otherwise a new
+         subcluster is started. Setting this value very high promotes splitting, and
+         vice-versa.
+
+     branching_factor : int = 50
+         Maximum number of 'BitFeatures' subclusters in each node. If a new sample
+         enters such that the number of subclusters exceeds the branching_factor, then
+         the node is split into two nodes, with the subclusters redistributed between
+         them. The parent subcluster of that node is removed, and two new subclusters
+         are added as parents of the two split nodes.
+
+     merge_criterion : str
+         One of 'radius', 'diameter' or 'tolerance'. *radius*: merge a subcluster
+         based on its similarity to the centroid of the cluster. *diameter*: merge a
+         subcluster based on the instant Tanimoto similarity of the resulting cluster.
+         *tolerance*: applies a tolerance penalty to the diameter criterion, merging
+         subclusters with a stricter threshold for newly added molecules.
+
+     tolerance : float
+         Penalty value for the similarity threshold of the 'tolerance' merge
+         criterion.
+
+     Notes
+     -----
+
+     The tree data structure consists of nodes, with each node holding a number of
+     subclusters (``BitFeatures``). The maximum number of subclusters in a node is
+     determined by the branching factor. Each subcluster maintains a linear sum,
+     mol_indices and the number of samples in that subcluster. In addition, each
+     subcluster can also have a node as its child, if the subcluster is not a member
+     of a leaf node.
+
+     Each time a new fingerprint is fitted, it is merged with the subcluster closest
+     to it, and the linear sum, mol_indices and the number of samples in the
+     corresponding subcluster are updated. This is done recursively until the
+     properties of a leaf node are updated.
+
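+     Examples
+     --------
+     A minimal usage sketch (the random packed fingerprints below are purely
+     illustrative stand-ins for real 2048-bit fingerprints)::
+
+         import numpy as np
+
+         rng = np.random.default_rng(0)
+         # 100 packed fingerprints: 2048 bits -> 256 uint8 bytes per row
+         fps = rng.integers(0, 256, size=(100, 256), dtype=np.uint8)
+         tree = BitBirch(threshold=0.65, merge_criterion="diameter")
+         tree.fit(fps, input_is_packed=True)
+         mol_ids = tree.get_cluster_mol_ids()  # list of lists of molecule indices
+         centroids = tree.get_centroids()  # list of packed centroid arrays
+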
+     """
+
+     def __init__(
+         self,
+         *,
+         threshold: float = 0.65,
+         branching_factor: int = 50,
+         merge_criterion: str | MergeAcceptFunction | None = None,
+         tolerance: float | None = None,
+     ):
+         # Criterion for merges
+         self.threshold = threshold
+         self.branching_factor = branching_factor
+         if _global_merge_accept is not None:
+             # Backwards compat
+             if tolerance is not None:
+                 raise ValueError(
+                     "tolerance can only be passed if "
+                     "the *global* set_merge function has *not* been used"
+                 )
+             if merge_criterion is not None:
+                 raise ValueError(
+                     "merge_criterion can only be passed if "
+                     "the *global* set_merge function has *not* been used"
+                 )
+             self._merge_accept_fn = _global_merge_accept
+         elif isinstance(merge_criterion, MergeAcceptFunction):
+             # NOTE: Only raise for an *explicitly* passed tolerance; defaulting the
+             # tolerance before this check would make custom merge fns always raise
+             if tolerance is not None:
+                 raise ValueError(
+                     "'tolerance' arg is disregarded for custom merge functions"
+                 )
+             self._merge_accept_fn = merge_criterion
+         else:
+             merge_criterion = "diameter" if merge_criterion is None else merge_criterion
+             tolerance = 0.05 if tolerance is None else tolerance
+             self._merge_accept_fn = get_merge_accept_fn(merge_criterion, tolerance)
+
+         # Tree state
+         self._num_fitted_fps = 0
+         self._root: _BFNode | None = None
+         self._dummy_leaf = _BFNode(branching_factor=2, n_features=0)
+         # TODO: Type correctly
+         self._global_clustering_centroid_labels: NDArray[np.int64] | None = None
+         self._n_global_clusters = 0
+
+         # For backwards compatibility, weak-register in global state. This is used to
+         # update the merge_accept function if the global set_merge() is called
+         # (discouraged)
+         _BITBIRCH_INSTANCES.add(self)
+
+     @property
+     def merge_criterion(self) -> str:
+         return self._merge_accept_fn.name
+
+     @merge_criterion.setter
+     def merge_criterion(self, value: str) -> None:
+         self.set_merge(criterion=value)
+
+     @property
+     def tolerance(self) -> float | None:
+         fn = self._merge_accept_fn
+         if hasattr(fn, "tolerance"):
+             return fn.tolerance
+         return None
+
+     @tolerance.setter
+     def tolerance(self, value: float) -> None:
+         self.set_merge(tolerance=value)
+
+     @property
+     def is_init(self) -> bool:
+         r"""Whether the tree has been initialized (True after first call to `fit()`)"""
+         return self._dummy_leaf._next_leaf is not None
+
+     @property
+     def num_fitted_fps(self) -> int:
+         r"""Total number of fitted fingerprints"""
+         return self._num_fitted_fps
+
+     def set_merge(
+         self,
+         criterion: str | MergeAcceptFunction | None = None,
+         *,
+         tolerance: float | None = None,
+         threshold: float | None = None,
+         branching_factor: int | None = None,
+     ) -> None:
+         r"""Changes the criteria for merging subclusters in this BitBirch tree
+
+         For an explanation of the parameters see the `BitBirch` class docstring.
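+
+         A minimal sketch::
+
+             tree = BitBirch()
+             tree.set_merge("tolerance", tolerance=0.05, threshold=0.7)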
+         """
+         if _global_merge_accept is not None:
+             raise ValueError(
+                 "BitBirch.set_merge() can only be called if "
+                 "the global set_merge() function has *not* been used"
+             )
+         _tolerance = 0.05 if tolerance is None else tolerance
+         if isinstance(criterion, MergeAcceptFunction):
+             self._merge_accept_fn = criterion
+         elif isinstance(criterion, str):
+             self._merge_accept_fn = get_merge_accept_fn(criterion, _tolerance)
+         if hasattr(self._merge_accept_fn, "tolerance"):
+             self._merge_accept_fn.tolerance = _tolerance
+         elif tolerance is not None:
+             raise ValueError(f"Can't set tolerance for {self._merge_accept_fn}")
+         if threshold is not None:
+             self.threshold = threshold
+         if branching_factor is not None:
+             self.branching_factor = branching_factor
+
+     def fit(
+         self,
+         X: _Input | Path | str,
+         /,
+         reinsert_indices: tp.Iterable[int] | None = None,
+         input_is_packed: bool = True,
+         n_features: int | None = None,
+         max_fps: int | None = None,
+     ) -> tpx.Self:
+         r"""Build a BF Tree for the input data.
+
+         Parameters
+         ----------
+
+         X : {array-like, sparse matrix} of shape (n_samples, n_features)
+             Input data.
+
+         reinsert_indices : Iterable[int]
+             If ``reinsert_indices`` is passed, ``X`` corresponds only to the molecules
+             that will be reinserted into the tree, and ``reinsert_indices`` are the
+             indices associated with these molecules.
+
+         input_is_packed : bool
+             Whether the input fingerprints are packed.
+
+         n_features : int
+             Number of features of the input fingerprints. Only required for packed
+             inputs if the number of features is not a multiple of 8, otherwise it
+             is redundant.
+
+         Returns
+         -------
+
+         self
+             Fitted estimator.
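+
+         Examples
+         --------
+         A minimal sketch (``fps`` and ``more_fps`` stand in for packed uint8
+         fingerprint arrays)::
+
+             tree = BitBirch().fit(fps)
+             # fit() can be called again; molecule indices continue counting from
+             # the previously fitted fingerprints
+             tree.fit(more_fps)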
+         """
+         if isinstance(X, (Path, str)):
+             X = _mmap_file_and_madvise_sequential(Path(X), max_fps=max_fps)
+             mmanager = _ArrayMemPagesManager.from_bb_input(X)
+         else:
+             X = X[:max_fps]
+             mmanager = _ArrayMemPagesManager.from_bb_input(X, can_release=False)
+
+         n_features = _validate_n_features(X, input_is_packed, n_features)
+         # Start a new tree the first time this function is called
+         if self._only_has_leaves:
+             raise ValueError("Internal nodes were released, call reset() before fit()")
+         if not self.is_init:
+             self._initialize_tree(n_features)
+         self._root = cast("_BFNode", self._root)  # After init, this is not None
+
+         # The array iterator either copies, un-sparsifies, or does nothing
+         # with the array rows, depending on the kind of X passed
+         arr_iterable = _get_array_iterable(X, input_is_packed, n_features)
+         arr_iterable = cast(tp.Iterable[NDArray[np.uint8]], arr_iterable)
+         iterable: tp.Iterable[tuple[int, NDArray[np.uint8]]]
+         if reinsert_indices is None:
+             iterable = enumerate(arr_iterable, self.num_fitted_fps)
+         else:
+             iterable = zip(reinsert_indices, arr_iterable)
+
+         threshold = self.threshold
+         branching_factor = self.branching_factor
+         merge_accept_fn = self._merge_accept_fn
+
+         arr_idx = 0
+         for idx, fp in iterable:
+             subcluster = _BFSubcluster(
+                 linear_sum=fp, mol_indices=[idx], n_features=n_features
+             )
+             split = self._root.insert_bf_subcluster(
+                 subcluster, merge_accept_fn, threshold
+             )
+
+             if split:
+                 new_subcluster1, new_subcluster2 = _split_node(self._root)
+                 self._root = _BFNode(branching_factor, n_features)
+                 self._root.append_subcluster(new_subcluster1)
+                 self._root.append_subcluster(new_subcluster2)
+
+             self._num_fitted_fps += 1
+             arr_idx += 1
+             if mmanager.can_release and mmanager.should_release_curr_page(arr_idx):
+                 mmanager.release_curr_page_and_update_addr()
+         return self
+
+     def _fit_buffers(
+         self,
+         X: _Input | Path | str,
+         reinsert_index_seqs: (
+             tp.Iterable[tp.Sequence[int]] | tp.Literal["omit"]
+         ) = "omit",
+     ) -> tpx.Self:
+         r"""Build a BF Tree starting from buffers
+
+         Buffers are arrays of the form:
+
+         - buffer[0:-1] = linear_sum
+         - buffer[-1] = n_samples
+
+         And X is either an array or a list of such buffers.
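+
+         For example, a singleton buffer for an 8-feature fingerprint would look like
+         this (illustrative sketch only)::
+
+             # linear_sum = [1, 0, 1, 1, 0, 0, 0, 1], n_samples = 1
+             buf = np.array([1, 0, 1, 1, 0, 0, 0, 1, 1], dtype=np.uint8)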
+
+         If `reinsert_index_seqs` is passed, X corresponds only to the buffers to be
+         reinserted into the tree, and `reinsert_index_seqs` are the sequences
+         of indices associated with such buffers.
+
+         If `reinsert_index_seqs` is "omit", then no indices are collected in the tree.
+
+         Parameters
+         ----------
+         X : {array-like, sparse matrix} of shape (n_samples + 1, n_features)
+             Input data.
+
+         Returns
+         -------
+         self
+             Fitted estimator.
+         """
+         if isinstance(X, (Path, str)):
+             X = _mmap_file_and_madvise_sequential(Path(X))
+             mmanager = _ArrayMemPagesManager.from_bb_input(X)
+         else:
+             mmanager = _ArrayMemPagesManager.from_bb_input(X, can_release=False)
+
+         n_features = _validate_n_features(X, input_is_packed=False) - 1
+         # Start a new tree the first time this function is called
+         if self._only_has_leaves:
+             raise ValueError("Internal nodes were released, call reset() before fit()")
+         if not self.is_init:
+             self._initialize_tree(n_features)
+         self._root = cast("_BFNode", self._root)  # After init, this is not None
+
+         # The array iterator either copies, un-sparsifies, or does nothing with the
+         # array rows, depending on the kind of X passed
+         arr_iterable = _get_array_iterable(X, input_is_packed=False, dtype=X[0].dtype)
+         merge_accept_fn = self._merge_accept_fn
+         threshold = self.threshold
+         branching_factor = self.branching_factor
+         idx_provider: tp.Iterable[tp.Sequence[int]]
+         arr_idx = 0
+         if reinsert_index_seqs == "omit":
+             idx_provider = (() for idx in range(self.num_fitted_fps))
+             check = False
+         else:
+             idx_provider = reinsert_index_seqs
+             check = True
+         for idxs, buf in zip(idx_provider, arr_iterable):
+             subcluster = _BFSubcluster(
+                 buffer=buf, mol_indices=idxs, n_features=n_features, check_indices=check
+             )
+             split = self._root.insert_bf_subcluster(
+                 subcluster, merge_accept_fn, threshold
+             )
+
+             if split:
+                 new_subcluster1, new_subcluster2 = _split_node(self._root)
+                 self._root = _BFNode(branching_factor, n_features)
+                 self._root.append_subcluster(new_subcluster1)
+                 self._root.append_subcluster(new_subcluster2)
+
+             self._num_fitted_fps += len(idxs)
+             arr_idx += 1
+             if mmanager.can_release and mmanager.should_release_curr_page(arr_idx):
+                 mmanager.release_curr_page_and_update_addr()
+         return self
+
+     # Provided for backwards compatibility
+     def fit_reinsert(
+         self,
+         X: _Input | Path | str,
+         reinsert_indices: tp.Iterable[int],
+         input_is_packed: bool = True,
+         n_features: int | None = None,
+         max_fps: int | None = None,
+     ) -> tpx.Self:
+         r""":meta private:"""
+         return self.fit(X, reinsert_indices, input_is_packed, n_features, max_fps)
+
+     def _initialize_tree(self, n_features: int) -> None:
+         # Initialize the root (and a dummy leaf used to get back the subclusters)
+         self._root = _BFNode(self.branching_factor, n_features)
+         self._dummy_leaf._next_leaf = self._root
+         self._root._prev_leaf = self._dummy_leaf
+
+     def _get_leaves(self) -> tp.Iterator[_BFNode]:
+         r"""Yields all leaf nodes"""
+         if not self.is_init:
+             raise ValueError("The model has not been fitted yet.")
+         leaf = self._dummy_leaf._next_leaf
+         while leaf is not None:
+             yield leaf
+             leaf = leaf._next_leaf
+
+     def get_centroids_mol_ids(
+         self, sort: bool = True, packed: bool = True
+     ) -> _CentroidsMolIds:
+         """Get a dict with the centroids and mol indices of the leaves"""
+         # NOTE: This is different from the original bitbirch, here outputs are
+         # sorted by default
+         centroids = []
+         mol_ids = []
+         attr = "packed_centroid" if packed else "unpacked_centroid"
+         for subcluster in self._get_leaf_bfs(sort=sort):
+             centroids.append(getattr(subcluster, attr))
+             mol_ids.append(subcluster.mol_indices)
+         return {"centroids": centroids, "mol_ids": mol_ids}
+
+     def get_centroids(
+         self,
+         sort: bool = True,
+         packed: bool = True,
+     ) -> list[NDArray[np.uint8]]:
+         r"""Get a list of arrays with the centroids' fingerprints"""
+         # NOTE: This is different from the original bitbirch, here outputs are
+         # sorted by default
+         attr = "packed_centroid" if packed else "unpacked_centroid"
+         return [getattr(s, attr) for s in self._get_leaf_bfs(sort=sort)]
+
+     def get_medoids_mol_ids(
+         self,
+         fps: NDArray[np.uint8],
+         sort: bool = True,
+         pack: bool = True,
+         global_clusters: bool = False,
+         input_is_packed: bool = True,
+         n_features: int | None = None,
+     ) -> _MedoidsMolIds:
+         """Get a dict with the medoids and mol indices of the leaves"""
+         cluster_members = self.get_cluster_mol_ids(
+             sort=sort, global_clusters=global_clusters
+         )
+
+         if input_is_packed:
+             fps = _unpack_fingerprints(fps, n_features=n_features)
+         cluster_medoids = self._unpacked_medoids_from_members(fps, cluster_members)
+         if pack:
+             cluster_medoids = pack_fingerprints(cluster_medoids)
+         return {"medoids": cluster_medoids, "mol_ids": cluster_members}
+
+     @staticmethod
+     def _unpacked_medoids_from_members(
+         unpacked_fps: NDArray[np.uint8], cluster_members: tp.Sequence[list[int]]
+     ) -> NDArray[np.uint8]:
+         cluster_medoids = np.zeros(
+             (len(cluster_members), unpacked_fps.shape[1]), dtype=np.uint8
+         )
+         for idx, members in enumerate(cluster_members):
+             cluster_medoids[idx, :] = jt_isim_medoid(
+                 unpacked_fps[members],
+                 input_is_packed=False,
+                 pack=False,
+             )[1]
+         return cluster_medoids
+
+     def get_medoids(
+         self,
+         fps: NDArray[np.uint8],
+         sort: bool = True,
+         pack: bool = True,
+         global_clusters: bool = False,
+         input_is_packed: bool = True,
+         n_features: int | None = None,
+     ) -> NDArray[np.uint8]:
+         return self.get_medoids_mol_ids(
+             fps, sort, pack, global_clusters, input_is_packed, n_features
+         )["medoids"]
+
+     def get_cluster_mol_ids(
+         self, sort: bool = True, global_clusters: bool = False
+     ) -> list[list[int]]:
+         r"""Get the indices of the molecules in each cluster"""
+         if global_clusters:
+             if self._global_clustering_centroid_labels is None:
+                 raise ValueError(
+                     "Must perform global clustering before fetching global labels"
+                 )
+             bf_labels = (
+                 self._global_clustering_centroid_labels - 1
+             )  # sub 1 to use as idxs
+
+             # Collect the members of all clusters
+             it = (bf.mol_indices for bf in self._get_leaf_bfs(sort=sort))
+             return self._new_ids_from_labels(it, bf_labels, self._n_global_clusters)
+
+         return [s.mol_indices for s in self._get_leaf_bfs(sort=sort)]
+
+     @staticmethod
+     def _new_ids_from_labels(
+         members: tp.Iterable[list[int]],
+         labels: NDArray[np.int64],
+         n_labels: int | None = None,
+     ) -> list[list[int]]:
+         r"""Regroup the indices of the molecules according to the given labels"""
+         if n_labels is None:
+             n_labels = len(np.unique(labels))
+         new_members: list[list[int]] = [[] for _ in range(n_labels)]
+         for i, idxs in enumerate(members):
+             new_members[labels[i]].extend(idxs)
+         return new_members
+
+     def get_assignments(
+         self,
+         n_mols: int | None = None,
+         sort: bool = True,
+         check_valid: bool = True,
+         global_clusters: bool = False,
+     ) -> NDArray[np.uint64]:
+         r"""Get an array with the cluster labels associated with each fingerprint idx"""
+         if n_mols is not None:
+             warnings.warn("The n_mols argument is redundant", DeprecationWarning)
+         if n_mols is not None and n_mols != self.num_fitted_fps:
+             raise ValueError(
+                 f"Provided n_mols {n_mols} is different"
+                 f" from the number of fitted fingerprints {self.num_fitted_fps}"
+             )
+         if check_valid:
+             assignments = np.full(self.num_fitted_fps, 0, dtype=np.uint64)
+         else:
+             assignments = np.empty(self.num_fitted_fps, dtype=np.uint64)
+
+         iterator: tp.Iterable[list[int]]
+         if sort:
+             iterator = self.get_cluster_mol_ids(sort=True)
+         else:
+             iterator = (
+                 s.mol_indices for leaf in self._get_leaves() for s in leaf._subclusters
+             )
+
+         if global_clusters:
+             if self._global_clustering_centroid_labels is None:
+                 raise ValueError(
+                     "Must perform global clustering before fetching global labels"
+                 )
+             # Assign according to global clustering labels
+             final_labels = self._global_clustering_centroid_labels
+             for mol_ids, label in zip(iterator, final_labels):
+                 assignments[mol_ids] = label
+         else:
+             # Assign according to mol_ids from the subclusters
+             for i, mol_ids in enumerate(iterator, 1):
+                 assignments[mol_ids] = i
+
+         # Check that there are no unassigned molecules
+         if check_valid and (assignments == 0).any():
+             raise ValueError("There are unassigned molecules")
+         return assignments
+
+     def dump_assignments(
+         self,
+         path: Path | str,
+         smiles: tp.Iterable[str] = (),
+         sort: bool = True,
+         global_clusters: bool = False,
+         check_valid: bool = True,
+     ) -> None:
+         r"""Dump the cluster assignments to a ``*.csv`` file"""
+         import pandas as pd  # Hide pandas import since it is heavy
+
+         path = Path(path)
+         if isinstance(smiles, str):
+             smiles = [smiles]
+         smiles = np.asarray(smiles, dtype=np.str_)
+         # Dump cluster assignments to *.csv
+         assignments = self.get_assignments(
+             sort=sort, check_valid=check_valid, global_clusters=global_clusters
+         )
+         if smiles.size and (len(assignments) != len(smiles)):
+             raise ValueError(
+                 f"Length of the provided smiles {len(smiles)}"
+                 f" must match the number of fitted fingerprints {self.num_fitted_fps}"
+             )
+         df = pd.DataFrame({"assignments": assignments})
+         if smiles.size:
+             df["smiles"] = smiles
+         df.to_csv(path, index=False)
+
+     def reset(self) -> None:
+         r"""Reset the tree state
+
+         Deletes *all internal nodes and leaves*; does not reset the merge criterion
+         or other merge parameters.
+         """
+         # Reset the whole tree
+         if self._root is not None:
+             self._root._prev_leaf = None
+             self._root._next_leaf = None
+         self._dummy_leaf._next_leaf = None
+         self._root = None
+         self._num_fitted_fps = 0
+
+     def delete_internal_nodes(self) -> None:
+         r"""Delete all nodes in the tree that are not leaves
+
+         This function is for advanced usage only. It should be called if there is a
+         need to use the BitBirch leaf clusters, but you need to release the memory
+         held by the internal nodes. After calling this function, no more fingerprints
+         can be fit into the tree, unless a call to `BitBirch.reset` afterwards
+         releases the *whole tree*, including the leaf clusters.
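+
+         A minimal sketch (assuming ``tree`` is a fitted ``BitBirch``; after the
+         call, leaf-based getters still work but ``fit`` raises until ``reset``)::
+
+             tree.delete_internal_nodes()
+             ids = tree.get_cluster_mol_ids()  # still available
+             tree.reset()  # releases the leaf clusters too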
+         """
+         if not tp.cast(_BFNode, self._root).is_leaf:
+             # Release all nodes that are not leaves; the leaves are kept alive by
+             # references from dummy_leaf
+             self._root = None
+
+     @property
+     def _only_has_leaves(self) -> bool:
+         return (self._root is None) and (self._dummy_leaf._next_leaf is not None)
+
+     def recluster_inplace(
+         self,
+         iterations: int = 1,
+         extra_threshold: float = 0.0,
+         shuffle: bool = False,
+         seed: int | None = None,
+         verbose: bool = False,
+         stop_early: bool = False,
+     ) -> tpx.Self:
+         r"""Refine singleton clusters by re-inserting them into the tree
+
+         Parameters
+         ----------
+         extra_threshold : float, default=0.0
+             The amount by which to increase the current threshold in each iteration.
+
+         iterations : int, default=1
+             The maximum number of refinement iterations to perform.
+
+         Returns
+         -------
+         self : BitBirch
+             The fitted estimator with refined clusters.
+
+         Raises
+         ------
+         ValueError
+             If the model has not been fitted.
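+
+         Examples
+         --------
+         A minimal sketch (assuming ``tree`` is an already-fitted ``BitBirch``)::
+
+             tree.recluster_inplace(iterations=2, shuffle=True, seed=42, verbose=True)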
+         """
+         if not self.is_init:
+             raise ValueError("The model has not been fitted yet.")
+
+         singletons_before = 0
+         for _ in range(iterations):
+             # Get the BFs
+             bfs = self._get_leaf_bfs(sort=True)
+
+             # Count the number of clusters and singletons
+             singleton_bfs = sum(1 for bf in bfs if bf.n_samples == 1)
+
+             # Check the stopping criteria
+             if stop_early:
+                 if singleton_bfs == 0 or singleton_bfs == singletons_before:
+                     # No more singletons to refine
+                     break
+                 singletons_before = singleton_bfs
+
+             # Print progress
+             if verbose:
+                 print(f"Current number of clusters: {len(bfs)}")
+                 print(f"Current number of singletons: {singleton_bfs}")
+
+             if shuffle:
+                 random.seed(seed)
+                 random.shuffle(bfs)
+
+             # Prepare the buffers for refitting
+             fps_bfs, mols_bfs = self._prepare_bf_to_buffer_dicts(bfs)
+
+             # Reset the tree
+             self.reset()
+
+             # Change the threshold
+             self.threshold += extra_threshold
+
+             # Refit the subclusters
+             for bufs, mol_idxs in zip(fps_bfs.values(), mols_bfs.values()):
+                 self._fit_buffers(bufs, reinsert_index_seqs=mol_idxs)
+
+         # Print final stats
+         if verbose:
+             bfs = self._get_leaf_bfs(sort=True)
+             singleton_bfs = sum(1 for bf in bfs if bf.n_samples == 1)
+             print(f"Final number of clusters: {len(bfs)}")
+             print(f"Final number of singletons: {singleton_bfs}")
+         return self
+
+     def refine_inplace(
+         self,
+         X: _Input | Path | str | tp.Sequence[Path],
+         initial_mol: int = 0,
+         input_is_packed: bool = True,
+         n_largest: int = 1,
+     ) -> tpx.Self:
+         r"""Refine the tree: break the largest clusters into singletons and re-fit"""
+         if not self.is_init:
+             raise ValueError("The model has not been fitted yet.")
+         # Release the memory held by non-leaf nodes
+         self.delete_internal_nodes()
+
+         # Extract the BitFeatures of the leaves, breaking the largest clusters apart
+         # into singleton subclusters
+         fps_bfs, mols_bfs = self._bf_to_np_refine(  # This function takes a bunch of mem
+             X,
+             initial_mol=initial_mol,
+             input_is_packed=input_is_packed,
+             n_largest=n_largest,
+         )
+         # Reset the tree
+         self.reset()
+
+         # Rebuild the tree again from scratch, reinserting all the subclusters
+         for bufs, mol_idxs in zip(fps_bfs.values(), mols_bfs.values()):
+             self._fit_buffers(bufs, reinsert_index_seqs=mol_idxs)
+         return self
+
+     def _get_leaf_bfs(self, sort: bool = True) -> list[_BFSubcluster]:
+         r"""Get the BitFeatures of the leaves"""
+         bfs = [s for leaf in self._get_leaves() for s in leaf._subclusters]
+         if sort:
+             # Sort the BitFeatures by the number of samples in the cluster
+             bfs.sort(key=lambda x: x.n_samples, reverse=True)
+         return bfs
+
+     def _bf_to_np_refine(
+         self,
+         X: _Input | Path | str | tp.Sequence[Path],
+         initial_mol: int = 0,
+         input_is_packed: bool = True,
+         n_largest: int = 1,
+     ) -> tuple[dict[str, list[NDArray[np.integer]]], dict[str, list[list[int]]]]:
+         """Prepare numpy bufs ('np') for BitFeatures, splitting the biggest n clusters
+
+         The largest clusters are split into singletons. In order to perform this
+         split, the *original* fingerprint array used to fit the tree (X) has to be
+         provided, together with the index associated with the first fingerprint.
+
+         The split is only performed for the returned 'np' buffers; the clusters in
+         the tree itself are not modified.
+         """
+         if n_largest == 0:
+             return self._bf_to_np()
+
+         if n_largest < 1:
+             raise ValueError("n_largest must be >= 1")
+
+         bfs = self._get_leaf_bfs()
+         largest = bfs[:n_largest]
+         rest = bfs[n_largest:]
+         n_features = largest[0].n_features
+
+         dtypes_to_fp, dtypes_to_mols = self._prepare_bf_to_buffer_dicts(rest)
+         # Add X and the mol indices of the "big" clusters
+         if input_is_packed:
+             unpack_or_copy = lambda x: _unpack_fingerprints(
+                 cast(NDArray[np.uint8], x), n_features
+             )
+         else:
+             unpack_or_copy = lambda x: x.copy()
+
+         for big_bf in largest:
+             full_arr_idxs = [(idx - initial_mol) for idx in big_bf.mol_indices]
+             _X: _Input
+             if isinstance(X, (Path, str)):
+                 # Only load the specific required mol idxs
+                 _X = cast(NDArray[np.integer], np.load(X, mmap_mode="r"))[full_arr_idxs]
+                 arr_idxs = list(range(len(_X)))
+                 mol_idxs = big_bf.mol_indices
+             elif isinstance(X[0], Path):
+                 # Only load the specific required mol idxs
+                 sort_idxs = np.argsort(full_arr_idxs)
+                 _X = _get_fingerprints_from_file_seq(
+                     cast(tp.Sequence[Path], X),
+                     [full_arr_idxs[i] for i in sort_idxs],
+                 )
+                 arr_idxs = list(range(len(_X)))
+                 mol_idxs = big_bf.mol_indices
+                 mol_idxs = [mol_idxs[i] for i in sort_idxs]
+             else:
+                 # Index the full array / list
+                 _X = cast(_Input, X)
+                 arr_idxs = full_arr_idxs
+                 mol_idxs = big_bf.mol_indices
+
+             for mol_idx, arr_idx in zip(mol_idxs, arr_idxs):
+                 buffer = np.empty(n_features + 1, dtype=np.uint8)
+                 buffer[:-1] = unpack_or_copy(_X[arr_idx])
+                 buffer[-1] = 1
+                 dtypes_to_fp["uint8"].append(buffer)
+                 dtypes_to_mols["uint8"].append([mol_idx])
+         return dtypes_to_fp, dtypes_to_mols
+
+     def _bf_to_np(
+         self,
+     ) -> tuple[dict[str, list[NDArray[np.integer]]], dict[str, list[list[int]]]]:
+         """Prepare numpy buffers ('np') for the BitFeatures of all clusters"""
+         return self._prepare_bf_to_buffer_dicts(self._get_leaf_bfs())
+
+     @staticmethod
+     def _prepare_bf_to_buffer_dicts(
+         bfs: list["_BFSubcluster"],
+     ) -> tuple[dict[str, list[NDArray[np.integer]]], dict[str, list[list[int]]]]:
+         # Helper function used when returning lists of subclusters
+         dtypes_to_fp = defaultdict(list)
+         dtypes_to_mols = defaultdict(list)
+         for bf in bfs:
+             dtypes_to_fp[bf.dtype_name].append(bf._buffer)
+             dtypes_to_mols[bf.dtype_name].append(bf.mol_indices)
+         return dtypes_to_fp, dtypes_to_mols
+
+     def __repr__(self) -> str:
+         fn = self._merge_accept_fn
+         parts = [
+             f"threshold={self.threshold}",
+             f"branching_factor={self.branching_factor}",
+             f"merge_criterion='{fn.name if fn.name in BUILTIN_MERGES else fn}'",
+         ]
+         if self.tolerance is not None:
+             parts.append(f"tolerance={self.tolerance}")
+         return f"{self.__class__.__name__}({', '.join(parts)})"
+
+     def global_clustering(
+         self,
+         n_clusters: int,
+         *,
+         method: str = "kmeans",
+         # TODO: Type correctly
+         **method_kwargs: tp.Any,
+     ) -> tpx.Self:
+         r""":meta private:"""
+         warnings.warn(
+             "Global clustering is an experimental feature;"
+             " it will be modified without warning, please do not use it"
+         )
+         if not self.is_init:
+             raise ValueError("The model has not been fitted yet.")
+         centroids = np.vstack(self.get_centroids(packed=False))
+         labels = self._centrals_global_clustering(
+             centroids, n_clusters, method=method, input_is_packed=False, **method_kwargs
+         )
+         num_centroids = len(centroids)
+         self._n_global_clusters = (
+             n_clusters if num_centroids > n_clusters else num_centroids
+         )
+         self._global_clustering_centroid_labels = labels
+         return self
+
+     @staticmethod
+     def _centrals_global_clustering(
+         centrals: NDArray[np.uint8],
+         n_clusters: int,
+         *,
+         method: str = "kmeans",
+         input_is_packed: bool = True,
+         n_features: int | None = None,
+         # TODO: Type correctly
+         **method_kwargs: tp.Any,
+     ) -> NDArray[np.int64]:
+         r""":meta private:"""
+         if method not in {"agglomerative", "kmeans", "kmeans-normalized"}:
+             raise ValueError(f"Unknown method {method}")
+         # Returns the labels associated with global clustering.
+         # Lazy import because sklearn is very heavy
+         from sklearn.cluster import KMeans, AgglomerativeClustering
+         from sklearn.exceptions import ConvergenceWarning
+
+         if input_is_packed:
+             centrals = _unpack_fingerprints(centrals, n_features)
+
+         num_centrals = len(centrals)
+         if num_centrals < n_clusters:
+             msg = (
+                 f"Number of subclusters found ({num_centrals}) by BitBIRCH is less"
+                 f" than the requested n_clusters ({n_clusters})."
+                 " Decrease n_clusters or the threshold."
+             )
+             warnings.warn(msg, ConvergenceWarning, stacklevel=2)
+             n_clusters = num_centrals
+
+         if method == "kmeans-normalized":
+             centrals = centrals / np.linalg.norm(centrals, axis=1, keepdims=True)
+         if method in ["kmeans", "kmeans-normalized"]:
+             predictor = KMeans(n_clusters=n_clusters, **method_kwargs)
+         elif method == "agglomerative":
+             predictor = AgglomerativeClustering(n_clusters=n_clusters, **method_kwargs)
+         else:
+             raise ValueError("method must be one of 'kmeans' or 'agglomerative'")
+
+         # Add 1 to start labels from 1 instead of 0, so 0 can be used as a sentinel
+         # value.
+         # This is the bottleneck for building this index:
+         # K-means is feasible, agglomerative is extremely expensive
+         return predictor.fit_predict(centrals) + 1
+
+
+ # There are 4 cases here:
+ # (1) The input is a scipy.sparse array
+ # (2) The input is a list of dense arrays (nothing required)
+ # (3) The input is a packed array or list of packed arrays (unpack required)
+ # (4) The input is a dense array (copy required)
+ # NOTE: The sparse iteration hack is taken from sklearn.
+ # It returns a densified row when iterating over a sparse matrix, instead
+ # of constructing a sparse matrix for every row, which is expensive.
+ #
+ # Output is *always* of dtype uint8, but the input (if unpacked) can be of arbitrary
+ # dtype. It is most efficient for the input to be uint8, to prevent copies
+ def _get_array_iterable(
+     X: _Input,
+     input_is_packed: bool = True,
+     n_features: int | None = None,
+     dtype: DTypeLike = np.uint8,
+ ) -> tp.Iterable[NDArray[np.integer]]:
+     if input_is_packed:
+         # Unpacking copies the fingerprints, so no extra copy is required
+         # NOTE: cast seems to have a very small overhead in this loop for some reason
+         return (_unpack_fingerprints(a, n_features) for a in X)  # type: ignore
+     if isinstance(X, list):
+         # No copy is required here unless the dtype is not uint8
+         return (a.astype(dtype, copy=False) for a in X)
+     if isinstance(X, np.ndarray):
+         # A copy is required here to avoid keeping a ref to the full array alive
+         return (a.astype(dtype, copy=True) for a in X)
+     return _iter_sparse(X)
+
+
+ # NOTE: In practice this branch is never used, it could probably safely be deleted
+ def _iter_sparse(X: tp.Any) -> tp.Iterator[NDArray[np.uint8]]:
+     import scipy.sparse  # Hide this import since scipy is heavy
+
+     if not scipy.sparse.issparse(X):
+         raise ValueError(f"Input of type {type(X)} is not supported")
+     n_samples, n_features = X.shape
+     X_indices = X.indices
+     X_data = X.data
+     X_indptr = X.indptr
+     for i in range(n_samples):
+         a = np.zeros(n_features, dtype=np.uint8)
+         startptr, endptr = X_indptr[i], X_indptr[i + 1]
+         nonzero_indices = X_indices[startptr:endptr]
+         a[nonzero_indices] = X_data[startptr:endptr].astype(np.uint8, copy=False)
+         yield a