bblean-0.6.0b2-cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bblean/__init__.py +22 -0
- bblean/_config.py +61 -0
- bblean/_console.py +187 -0
- bblean/_cpp_similarity.cp312-win_amd64.pyd +0 -0
- bblean/_legacy/__init__.py +0 -0
- bblean/_legacy/bb_int64.py +1252 -0
- bblean/_legacy/bb_uint8.py +1144 -0
- bblean/_memory.py +198 -0
- bblean/_merges.py +212 -0
- bblean/_py_similarity.py +278 -0
- bblean/_timer.py +42 -0
- bblean/_version.py +34 -0
- bblean/analysis.py +258 -0
- bblean/bitbirch.py +1437 -0
- bblean/cli.py +1850 -0
- bblean/csrc/README.md +1 -0
- bblean/csrc/similarity.cpp +521 -0
- bblean/fingerprints.py +424 -0
- bblean/metrics.py +199 -0
- bblean/multiround.py +489 -0
- bblean/plotting.py +479 -0
- bblean/similarity.py +304 -0
- bblean/sklearn.py +203 -0
- bblean/smiles.py +61 -0
- bblean/utils.py +130 -0
- bblean-0.6.0b2.dist-info/METADATA +288 -0
- bblean-0.6.0b2.dist-info/RECORD +31 -0
- bblean-0.6.0b2.dist-info/WHEEL +5 -0
- bblean-0.6.0b2.dist-info/entry_points.txt +2 -0
- bblean-0.6.0b2.dist-info/licenses/LICENSE +48 -0
- bblean-0.6.0b2.dist-info/top_level.txt +1 -0
bblean/bitbirch.py
ADDED
@@ -0,0 +1,1437 @@
# BitBIRCH-Lean Python Package: An open-source clustering module based on iSIM.
#
# If you find this software useful please cite the following articles:
# - BitBIRCH: efficient clustering of large molecular libraries:
#   https://doi.org/10.1039/D5DD00030K
# - BitBIRCH Clustering Refinement Strategies:
#   https://doi.org/10.1021/acs.jcim.5c00627
# - BitBIRCH-Lean: TO-BE-ADDED
#
# Copyright (C) 2025 The Miranda-Quintana Lab and other BitBirch developers, comprised
# exclusively by:
# - Ramon Alain Miranda Quintana <ramirandaq@gmail.com>, <quintana@chem.ufl.edu>
# - Krisztina Zsigmond <kzsigmond@ufl.edu>
# - Ignacio Pickering <ipickering@chem.ufl.edu>
# - Kenneth Lopez Perez <klopezperez@chem.ufl.edu>
# - Miroslav Lzicar <miroslav.lzicar@deepmedchem.com>
#
# Authors of ./bblean/multiround.py are:
# - Ramon Alain Miranda Quintana <ramirandaq@gmail.com>, <quintana@chem.ufl.edu>
# - Ignacio Pickering <ipickering@chem.ufl.edu>
#
# This program is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software Foundation,
# version 3 (SPDX-License-Identifier: GPL-3.0-only).
#
# Portions of this file are licensed under the BSD 3-Clause License
# Copyright (c) 2007-2024 The scikit-learn developers. All rights reserved.
# (SPDX-License-Identifier: BSD-3-Clause). Copies or reproductions of code in this
# file must in addition adhere to the BSD-3-Clause license terms. A
# copy of the BSD-3-Clause license can be located at the root of this repository, under
# ./LICENSES/BSD-3-Clause.txt.
#
# Portions of this file were previously licensed under the LGPL 3.0
# license (SPDX-License-Identifier: LGPL-3.0-only), they are relicensed in this program
# as GPL-3.0, with permission of all original copyright holders:
# - Ramon Alain Miranda Quintana <ramirandaq@gmail.com>, <quintana@chem.ufl.edu>
# - Vicky (Vic) Jung <jungvicky@ufl.edu>
# - Kenneth Lopez Perez <klopezperez@chem.ufl.edu>
# - Kate Huddleston <kdavis2@chem.ufl.edu>
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with this
# program. This copy can be located at the root of this repository, under
# ./LICENSES/GPL-3.0-only.txt. If not, see <http://www.gnu.org/licenses/gpl-3.0.html>.
r"""BitBirch 'Lean' class for fast, memory-efficient O(N) clustering"""
|
|
49
|
+
from __future__ import annotations # Stringize type annotations for no runtime overhead
|
|
50
|
+
import typing_extensions as tpx
|
|
51
|
+
import os
|
|
52
|
+
import random
|
|
53
|
+
from pathlib import Path
|
|
54
|
+
import warnings
|
|
55
|
+
import typing as tp
|
|
56
|
+
from typing import cast
|
|
57
|
+
from collections import defaultdict
|
|
58
|
+
from weakref import WeakSet
|
|
59
|
+
|
|
60
|
+
import numpy as np
|
|
61
|
+
from numpy.typing import NDArray, DTypeLike
|
|
62
|
+
|
|
63
|
+
from bblean._memory import _mmap_file_and_madvise_sequential, _ArrayMemPagesManager
|
|
64
|
+
from bblean._merges import get_merge_accept_fn, MergeAcceptFunction, BUILTIN_MERGES
|
|
65
|
+
from bblean.utils import min_safe_uint
|
|
66
|
+
from bblean.fingerprints import (
|
|
67
|
+
pack_fingerprints,
|
|
68
|
+
_get_fingerprints_from_file_seq,
|
|
69
|
+
)
|
|
70
|
+
from bblean.similarity import (
|
|
71
|
+
_jt_sim_arr_vec_packed,
|
|
72
|
+
jt_most_dissimilar_packed,
|
|
73
|
+
jt_isim_medoid,
|
|
74
|
+
centroid_from_sum,
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
if os.getenv("BITBIRCH_NO_EXTENSIONS"):
|
|
78
|
+
from bblean.fingerprints import unpack_fingerprints as _unpack_fingerprints
|
|
79
|
+
else:
|
|
80
|
+
try:
|
|
81
|
+
# NOTE: There are small gains from using this fn but only ~3%, so don't warn for
|
|
82
|
+
# now if this fails, and don't expose it
|
|
83
|
+
from bblean._cpp_similarity import unpack_fingerprints as _unpack_fingerprints # type: ignore # noqa
|
|
84
|
+
except ImportError:
|
|
85
|
+
from bblean.fingerprints import unpack_fingerprints as _unpack_fingerprints
|
|
86
|
+
|
|
87
|
+
__all__ = ["BitBirch"]
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# For backwards compatibility with the global "set_merge", keep weak references to all
|
|
91
|
+
# the BitBirch instances and update them when set_merge is called
|
|
92
|
+
_BITBIRCH_INSTANCES: WeakSet["BitBirch"] = WeakSet()
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
# For backwards compatibility: global function used to accept merges
|
|
96
|
+
_global_merge_accept: MergeAcceptFunction | None = None
|
|
97
|
+
|
|
98
|
+
_Input = tp.Union[NDArray[np.integer], list[NDArray[np.integer]]]
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
# For backwards compatibility: set the global merge_accept function
|
|
102
|
+
def set_merge(merge_criterion: str, tolerance: float = 0.05) -> None:
|
|
103
|
+
r"""Sets the global criteria for merging subclusters in any BitBirch tree
|
|
104
|
+
|
|
105
|
+
For usage see `BitBirch.set_merge`
|
|
106
|
+
|
|
107
|
+
.. warning::
|
|
108
|
+
|
|
109
|
+
Use of this function is highly discouraged, instead use either:
|
|
110
|
+
``bb_tree = BitBirch(...)``
|
|
111
|
+
``bb_tree.set_merge(merge_criterion=..., tolerance=...)``
|
|
112
|
+
or directly: ``bb_tree = BitBirch(..., merge_criterion=..., tolerance=...)``"
|
|
113
|
+
|
|
114
|
+
"""
|
|
115
|
+
msg = (
|
|
116
|
+
"Use of the global `set_merge` function is highly discouraged,\n"
|
|
117
|
+
" instead use either: "
|
|
118
|
+
" bb_tree = BitBirch(...)\n"
|
|
119
|
+
" bb_tree.set_merge(merge_criterion=..., tolerance=...)\n"
|
|
120
|
+
" or directly: `bb_tree = BitBirch(..., merge_criterion=..., tolerance=...)`."
|
|
121
|
+
)
|
|
122
|
+
warnings.warn(msg, UserWarning)
|
|
123
|
+
# Set the global merge_accept function
|
|
124
|
+
global _global_merge_accept
|
|
125
|
+
_global_merge_accept = get_merge_accept_fn(merge_criterion, tolerance)
|
|
126
|
+
for bbirch in _BITBIRCH_INSTANCES:
|
|
127
|
+
bbirch._merge_accept_fn = _global_merge_accept
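
# Illustrative usage sketch of the two merge-configuration styles (hedged example;
# `fps` is a placeholder fingerprint array, not defined here). The per-instance API
# is the one the warning above recommends:
#
#     bb_tree = BitBirch(merge_criterion="tolerance", tolerance=0.05)
#     bb_tree.set_merge("diameter")   # reconfigure just this tree
#     set_merge("diameter")           # discouraged: mutates *all* live BitBirch trees
#     bb_tree.fit(fps)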


# Utility function to validate the n_features argument for packed inputs
def _validate_n_features(
    X: _Input, input_is_packed: bool, n_features: int | None = None
) -> int:
    if len(X) == 0:
        raise ValueError("Input must have at least 1 fingerprint")
    if input_is_packed:
        _padded_n_features = len(X[0]) * 8 if isinstance(X, list) else X.shape[1] * 8
        if n_features is None:
            # Assume multiple of 8
            return _padded_n_features
        if _padded_n_features < n_features:
            raise ValueError(
                "n_features is larger than the padded length, which is inconsistent"
            )
        return n_features

    x_n_features = len(X[0]) if isinstance(X, list) else X.shape[1]
    if n_features is not None:
        if n_features != x_n_features:
            raise ValueError(
                "n_features is redundant for non-packed inputs;"
                " if passed, it must be equal to X.shape[1] (or len(X[0]))."
                f" For passed X the inferred n_features was {x_n_features}."
                " If this value is not what you expected,"
                " make sure the passed X is actually unpacked."
            )
    return x_n_features
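
# Worked example of the padding logic above (illustrative): a packed row of 129
# uint8 bytes implies a padded length of 129 * 8 = 1032 bits, so any
# n_features <= 1032 (e.g. 1027) is accepted, and 1032 is assumed if omitted:
#
#     X = np.zeros((10, 129), dtype=np.uint8)                         # packed
#     _validate_n_features(X, input_is_packed=True)                   # -> 1032
#     _validate_n_features(X, input_is_packed=True, n_features=1027)  # -> 1027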


def _split_node(node: "_BFNode") -> tuple["_BFSubcluster", "_BFSubcluster"]:
    """The node has to be split if there is no place for a new subcluster
    in the node.
    1. An extra empty node and two empty subclusters are initialized.
    2. The pair of most distant subclusters is found.
    3. The properties of the empty subclusters and nodes are updated
       according to the nearest distance between the subclusters to the
       pair of distant subclusters.
    4. The two nodes are set as children to the two subclusters.
    """
    n_features = node.n_features
    branching_factor = node.branching_factor
    new_subcluster1 = _BFSubcluster(n_features=n_features)
    new_subcluster2 = _BFSubcluster(n_features=n_features)

    node1 = _BFNode(branching_factor, n_features)
    node2 = node  # Rename for clarity
    new_subcluster1.child = node1
    new_subcluster2.child = node2

    if node2.is_leaf:
        # If is_leaf, _prev_leaf is guaranteed to be not None
        # NOTE: cast seems to have a small overhead here for some reason
        node1._prev_leaf = node2._prev_leaf
        node2._prev_leaf._next_leaf = node1  # type: ignore
        node1._next_leaf = node2
        node2._prev_leaf = node1

    # O(N) approximation to obtain "most dissimilar fingerprints" within an array
    node1_idx, _, node1_sim, node2_sim = jt_most_dissimilar_packed(
        node2.packed_centroids, n_features
    )
    node1_closer = node1_sim > node2_sim
    # Make sure node1 and node2 are closest to themselves, even if all sims are equal.
    # This can only happen when all node.packed_centroids are duplicates leading to all
    # distances between centroids being zero.

    # TODO: Currently this behavior is buggy (?), seems like in some cases one of the
    # subclusters may *never* get updated, double check this logic
    node1_closer[node1_idx] = True
    subclusters = node2._subclusters.copy()  # Shallow copy
    node2._subclusters = []  # Reset the node
    for idx, subcluster in enumerate(subclusters):
        if node1_closer[idx]:
            node1.append_subcluster(subcluster)
            new_subcluster1.update(subcluster)
        else:
            node2.append_subcluster(subcluster)
            new_subcluster2.update(subcluster)
    return new_subcluster1, new_subcluster2
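
# Sketch of the split just performed (illustrative values): two far-apart
# centroids are used as seeds, and each remaining subcluster goes to the seed
# it is more similar to, e.g. for elementwise similarities to seed 1 / seed 2:
#
#     node1_sim = [0.9, 0.2, 0.6]
#     node2_sim = [0.3, 0.8, 0.5]
#     node1_closer -> [True, False, True]   # members 0 and 2 move to node1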


class _BFNode:
    """Each node in a BitBirch tree is a _BFNode.

    The _BFNode holds a maximum of branching_factor _BFSubclusters.

    Parameters
    ----------
    branching_factor : int
        Maximum number of _BFSubcluster in the node.

    n_features : int
        The number of features.

    Attributes
    ----------
    _subclusters : list
        List of _BFSubcluster for the _BFNode.

    _prev_leaf : _BFNode
        Only useful for leaf nodes, otherwise None

    _next_leaf : _BFNode
        Only useful for leaf nodes, otherwise None

    _packed_centroids_buf : NDArray[np.uint8]
        Packed array of shape (branching_factor + 1, (n_features + 7) // 8). The code
        internally manipulates this buf rather than packed_centroids, which is just a
        view of this buf.

    packed_centroids : ndarray of shape (len(_subclusters), (n_features + 7) // 8)
        View of the valid section of ``_packed_centroids_buf``.
    """

    # NOTE: Slots deactivates __dict__, and thus reduces memory usage of python objects
    __slots__ = (
        "n_features",
        "_subclusters",
        "_packed_centroids_buf",
        "_prev_leaf",
        "_next_leaf",
    )

    def __init__(self, branching_factor: int, n_features: int):
        self.n_features = n_features
        # The list of subclusters, centroids and squared norms
        # to manipulate throughout.
        self._subclusters: list["_BFSubcluster"] = []
        # Centroids are stored packed. All centroids up to branching_factor are
        # allocated in a contiguous array
        self._packed_centroids_buf = np.empty(
            (branching_factor + 1, (n_features + 7) // 8), dtype=np.uint8
        )
        # Nodes that are leaves have a non-null _prev_leaf
        self._prev_leaf: tp.Optional["_BFNode"] = None
        self._next_leaf: tp.Optional["_BFNode"] = None

    @property
    def is_leaf(self) -> bool:
        return self._prev_leaf is not None

    @property
    def branching_factor(self) -> int:
        return self._packed_centroids_buf.shape[0] - 1

    @property
    def packed_centroids(self) -> NDArray[np.uint8]:
        # packed_centroids returns a view of the valid part of _packed_centroids_buf.
        return self._packed_centroids_buf[: len(self._subclusters), :]

    def append_subcluster(self, subcluster: "_BFSubcluster") -> None:
        n_samples = len(self._subclusters)
        self._subclusters.append(subcluster)
        self._packed_centroids_buf[n_samples] = subcluster.packed_centroid

    def update_split_subclusters(
        self,
        subcluster: "_BFSubcluster",
        new_subcluster1: "_BFSubcluster",
        new_subcluster2: "_BFSubcluster",
    ) -> None:
        """Remove a subcluster from a node and update it with the
        split subclusters.
        """
        # Replace subcluster with new_subcluster1
        idx = self._subclusters.index(subcluster)
        self._subclusters[idx] = new_subcluster1
        self._packed_centroids_buf[idx] = new_subcluster1.packed_centroid
        # Append new_subcluster2
        self.append_subcluster(new_subcluster2)

    def insert_bf_subcluster(
        self,
        subcluster: "_BFSubcluster",
        merge_accept_fn: MergeAcceptFunction,
        threshold: float,
    ) -> bool:
        """Insert a new subcluster into the node."""
        if not self._subclusters:
            self.append_subcluster(subcluster)
            return False

        # Within this node, find the closest subcluster to the one to-be-inserted
        sim_matrix = _jt_sim_arr_vec_packed(
            self.packed_centroids, subcluster.packed_centroid
        )
        closest_idx = np.argmax(sim_matrix)
        closest_subcluster = self._subclusters[closest_idx]
        closest_node = closest_subcluster.child
        if closest_node is None:
            # The subcluster doesn't have a child node (this is a leaf node)
            # attempt direct merge
            merge_was_successful = closest_subcluster.merge_subcluster(
                subcluster, threshold, merge_accept_fn
            )
            if not merge_was_successful:
                # Could not merge due to criteria
                # Append subcluster, and check if splitting *this node* is needed
                self.append_subcluster(subcluster)
                return len(self._subclusters) > self.branching_factor
            # Merge success, update the centroid
            self._packed_centroids_buf[closest_idx] = closest_subcluster.packed_centroid
            return False

        # Hard case: the closest subcluster has a child (is 'tracking'), use recursion
        child_must_be_split = closest_node.insert_bf_subcluster(
            subcluster, merge_accept_fn, threshold
        )
        if child_must_be_split:
            # Split the child node and redistribute subclusters. Update
            # this node with the 'tracking' subclusters of the two new children.
            # Then, check if *this node* needs splitting too
            new_subcluster1, new_subcluster2 = _split_node(closest_node)
            self.update_split_subclusters(
                closest_subcluster, new_subcluster1, new_subcluster2
            )
            return len(self._subclusters) > self.branching_factor

        # Child need not be split, update the *tracking* closest subcluster
        closest_subcluster.update(subcluster)
        self._packed_centroids_buf[closest_idx] = self._subclusters[
            closest_idx
        ].packed_centroid
        return False
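
# Contract sketch for the recursive insert above (illustrative): the boolean
# return value propagates "this node overflowed" one level up; the caller reacts
# by splitting the overflowing node and then re-checking itself, e.g.:
#
#     overflow = root.insert_bf_subcluster(sub, merge_accept_fn, threshold)
#     if overflow:  # root exceeded branching_factor; grow a new root above it
#         s1, s2 = _split_node(root)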


class _BFSubcluster:
    r"""Each subcluster in a BFNode is called a BFSubcluster.

    A BFSubcluster can have a BFNode as its child.

    Parameters
    ----------
    linear_sum : ndarray of shape (n_features,), default=None
        Sample. This is kept optional to allow initialization of empty
        subclusters.

    Attributes
    ----------
    n_samples : int
        Number of samples that belong to each subcluster.

    linear_sum : ndarray
        Linear sum of all the samples in a subcluster. Prevents holding
        all sample data in memory.

    packed_centroid : ndarray of shape ((n_features + 7) // 8,)
        Centroid of the subcluster. Prevents recomputing of centroids when
        ``BFNode.packed_centroids`` is called.

    mol_indices : list, default=[]
        List of indices of molecules included in the given cluster.

    child : _BFNode
        Child Node of the subcluster. Once a given _BFNode is set as the child
        of the subcluster, it is set to ``self.child``.
    """

    # NOTE: Slots deactivates __dict__, and thus reduces memory usage of python objects
    __slots__ = ("_buffer", "packed_centroid", "child", "mol_indices")

    def __init__(
        self,
        *,
        linear_sum: NDArray[np.integer] | None = None,
        mol_indices: tp.Sequence[int] = (),
        n_features: int = 2048,
        buffer: NDArray[np.integer] | None = None,
        check_indices: bool = True,
    ):
        # NOTE: Internally, _buffer holds both "linear_sum" and "n_samples". It is
        # guaranteed to always have the minimum required uint dtype. It should not be
        # accessed by external classes, only used internally. The individual parts can
        # be accessed in a read-only way using the linear_sum and n_samples
        # properties.
        #
        # IMPORTANT: To mutate instances of this class, *always* use the public API
        # given by replace|add_to_n_samples_and_linear_sum(...)
        if buffer is not None:
            if linear_sum is not None:
                raise ValueError("'linear_sum' and 'buffer' are mutually exclusive")
            if check_indices and len(mol_indices) != buffer[-1]:
                raise ValueError(
                    "Expected len(mol_indices) == buffer[-1],"
                    f" but found {len(mol_indices)} != {buffer[-1]}"
                )
            self._buffer = buffer
            self.packed_centroid = centroid_from_sum(buffer[:-1], buffer[-1], pack=True)
        else:
            if linear_sum is not None:
                if check_indices and len(mol_indices) != 1:
                    raise ValueError(
                        "Expected len(mol_indices) == 1,"
                        f" but found {len(mol_indices)} != 1"
                    )
                buffer = np.empty((len(linear_sum) + 1,), dtype=np.uint8)
                buffer[:-1] = linear_sum
                buffer[-1] = 1
                self._buffer = buffer
                self.packed_centroid = pack_fingerprints(
                    linear_sum.astype(np.uint8, copy=False)
                )
            else:
                # Empty subcluster
                if check_indices and len(mol_indices) != 0:
                    raise ValueError(
                        "Expected len(mol_indices) == 0 for empty subcluster,"
                        f" but found {len(mol_indices)} != 0"
                    )
                self._buffer = np.zeros((n_features + 1,), dtype=np.uint8)
                self.packed_centroid = np.empty(0, dtype=np.uint8)  # Will be overwritten
        self.mol_indices = list(mol_indices)
        self.child: tp.Optional["_BFNode"] = None

    @property
    def unpacked_centroid(self) -> NDArray[np.uint8]:
        return _unpack_fingerprints(self.packed_centroid, self.n_features)

    @property
    def n_features(self) -> int:
        return len(self._buffer) - 1

    @property
    def dtype_name(self) -> str:
        return self._buffer.dtype.name

    @property
    def linear_sum(self) -> NDArray[np.integer]:
        read_only_view = self._buffer[:-1]
        read_only_view.flags.writeable = False
        return read_only_view

    @property
    def n_samples(self) -> int:
        # Returns a python int, which is guaranteed to never overflow in sums, so
        # n_samples can always be safely added when accessed through this property
        return self._buffer.item(-1)

    # NOTE: Part of the contract is that all elements of linear sum must always be
    # less or equal to n_samples. This function does not check this
    def replace_n_samples_and_linear_sum(
        self, n_samples: int, linear_sum: NDArray[np.integer]
    ) -> None:
        # Cast to the minimum uint that can hold the inputs
        self._buffer = self._buffer.astype(min_safe_uint(n_samples), copy=False)
        # NOTE: Assignments are safe and do not recast the buffer
        self._buffer[:-1] = linear_sum
        self._buffer[-1] = n_samples
        self.packed_centroid = centroid_from_sum(linear_sum, n_samples, pack=True)

    # NOTE: Part of the contract is that all elements of linear sum must always be
    # less or equal to n_samples. This function does not check this
    def add_to_n_samples_and_linear_sum(
        self, n_samples: int, linear_sum: NDArray[np.integer]
    ) -> None:
        # Cast to the minimum uint that can hold the inputs
        new_n_samples = self.n_samples + n_samples
        self._buffer = self._buffer.astype(min_safe_uint(new_n_samples), copy=False)
        # NOTE: Assignment and inplace add are safe and do not recast the buffer
        self._buffer[:-1] += linear_sum
        self._buffer[-1] = new_n_samples
        self.packed_centroid = centroid_from_sum(
            self._buffer[:-1], new_n_samples, pack=True
        )

    def update(self, subcluster: "_BFSubcluster") -> None:
        self.add_to_n_samples_and_linear_sum(
            subcluster.n_samples, subcluster.linear_sum
        )
        self.mol_indices.extend(subcluster.mol_indices)

    def merge_subcluster(
        self,
        nominee_cluster: "_BFSubcluster",
        threshold: float,
        merge_accept_fn: MergeAcceptFunction,
    ) -> bool:
        """Check if a cluster is worthy enough to be merged. If yes, merge."""
        old_n = self.n_samples
        nom_n = nominee_cluster.n_samples
        new_n = old_n + nom_n
        old_ls = self.linear_sum
        nom_ls = nominee_cluster.linear_sum
        # np.add with explicit dtype is safe from overflows, e.g.:
        # np.add(np.uint8(255), np.uint8(255), dtype=np.uint16) = np.uint16(510)
        new_ls = np.add(old_ls, nom_ls, dtype=min_safe_uint(new_n))
        if merge_accept_fn(threshold, new_ls, new_n, old_ls, nom_ls, old_n, nom_n):
            self.replace_n_samples_and_linear_sum(new_n, new_ls)
            self.mol_indices.extend(nominee_cluster.mol_indices)
            return True
        return False
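
# Worked overflow example for the dtype logic above (illustrative): two uint8
# linear sums whose merged values exceed 255 are added into a wider dtype chosen
# by min_safe_uint, so no element can wrap around:
#
#     a = np.full(4, 200, dtype=np.uint8)
#     b = np.full(4, 100, dtype=np.uint8)
#     np.add(a, b, dtype=np.uint16)  # -> [300, 300, 300, 300], not 44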


class _CentroidsMolIds(tp.TypedDict):
    centroids: list[NDArray[np.uint8]]
    mol_ids: list[list[int]]


class _MedoidsMolIds(tp.TypedDict):
    medoids: NDArray[np.uint8]
    mol_ids: list[list[int]]


class BitBirch:
    r"""Implements the BitBIRCH clustering algorithm, 'Lean' version

    Memory and time efficient, online-learning algorithm. It constructs a tree data
    structure with the cluster centroids being read off the leaf.

    If you find this software useful please cite the following articles:

    - BitBIRCH: efficient clustering of large molecular libraries:
      https://doi.org/10.1039/D5DD00030K
    - BitBIRCH Clustering Refinement Strategies:
      https://doi.org/10.1021/acs.jcim.5c00627
    - BitBIRCH-Lean: TO-BE-ADDED

    Parameters
    ----------
    threshold : float = 0.65
        The similarity radius of the subcluster obtained by merging a new sample and
        the closest subcluster should be greater than the threshold. Otherwise a new
        subcluster is started. Setting this value to be very low promotes splitting
        and vice-versa.

    branching_factor : int = 50
        Maximum number of 'BitFeatures' subclusters in each node. If a new sample
        enters such that the number of subclusters exceeds the branching_factor then
        that node is split into two nodes with the subclusters redistributed in each.
        The parent subcluster of that node is removed and two new subclusters are
        added as parents of the 2 split nodes.

    merge_criterion : str
        radius, diameter or tolerance. *radius*: merge subcluster based on comparison
        to centroid of the cluster. *diameter*: merge subcluster based on instant
        Tanimoto similarity of cluster. *tolerance*: applies tolerance threshold to
        diameter merge criteria, which will merge subcluster with stricter threshold
        for newly added molecules.

    tolerance : float
        Penalty value for similarity threshold of the 'tolerance' merge criteria.

    Notes
    -----
    The tree data structure consists of nodes with each node holding a number of
    subclusters (``BitFeatures``). The maximum number of subclusters in a node is
    determined by the branching factor. Each subcluster maintains a linear sum,
    mol_indices and the number of samples in that subcluster. In addition, each
    subcluster can also have a node as its child, if the subcluster is not a member of
    a leaf node.

    Each time a new fingerprint is fitted, it is merged with the subcluster closest to
    it and the linear sum, mol_indices and the number of samples in the corresponding
    subcluster are updated. This is done recursively until the properties of a leaf
    node are updated.
    """

    def __init__(
        self,
        *,
        threshold: float = 0.65,
        branching_factor: int = 50,
        merge_criterion: str | MergeAcceptFunction | None = None,
        tolerance: float | None = None,
    ):
        # Criterion for merges
        self.threshold = threshold
        self.branching_factor = branching_factor
        if _global_merge_accept is not None:
            # Backwards compat
            if tolerance is not None:
                raise ValueError(
                    "tolerance can only be passed if "
                    "the *global* set_merge function has *not* been used"
                )
            if merge_criterion is not None:
                raise ValueError(
                    "merge_criterion can only be passed if "
                    "the *global* set_merge function has *not* been used"
                )
            self._merge_accept_fn = _global_merge_accept
        else:
            merge_criterion = "diameter" if merge_criterion is None else merge_criterion
            tolerance = 0.05 if tolerance is None else tolerance
            if isinstance(merge_criterion, MergeAcceptFunction):
                if tolerance is not None:
                    raise ValueError(
                        "'tolerance' arg is disregarded for custom merge functions"
                    )
                self._merge_accept_fn = merge_criterion
            else:
                self._merge_accept_fn = get_merge_accept_fn(merge_criterion, tolerance)

        # Tree state
        self._num_fitted_fps = 0
        self._root: _BFNode | None = None
        self._dummy_leaf = _BFNode(branching_factor=2, n_features=0)
        # TODO: Type correctly
        self._global_clustering_centroid_labels: NDArray[np.int64] | None = None
        self._n_global_clusters = 0

        # For backwards compatibility, weak-register in global state. This is used to
        # update the merge_accept function if the global set_merge() is called
        # (discouraged)
        _BITBIRCH_INSTANCES.add(self)
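
    # Illustrative construction sketch (hedged; exact clustering output depends on
    # the fingerprints): the constructor only stores the merge policy, the tree
    # itself is built lazily on the first fit() call:
    #
    #     bb_tree = BitBirch(threshold=0.65, branching_factor=50,
    #                        merge_criterion="diameter")
    #     bb_tree.is_init  # -> False until fit() initializes the root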

    @property
    def merge_criterion(self) -> str:
        return self._merge_accept_fn.name

    @merge_criterion.setter
    def merge_criterion(self, value: str) -> None:
        self.set_merge(criterion=value)

    @property
    def tolerance(self) -> float | None:
        fn = self._merge_accept_fn
        if hasattr(fn, "tolerance"):
            return fn.tolerance
        return None

    @tolerance.setter
    def tolerance(self, value: float) -> None:
        self.set_merge(tolerance=value)

    @property
    def is_init(self) -> bool:
        r"""Whether the tree has been initialized (True after first call to `fit()`)"""
        return self._dummy_leaf._next_leaf is not None

    @property
    def num_fitted_fps(self) -> int:
        r"""Total number of fitted fingerprints"""
        return self._num_fitted_fps

    def set_merge(
        self,
        criterion: str | MergeAcceptFunction | None = None,
        *,
        tolerance: float | None = None,
        threshold: float | None = None,
        branching_factor: int | None = None,
    ) -> None:
        r"""Changes the criteria for merging subclusters in this BitBirch tree

        For an explanation of the parameters see the `BitBirch` class docstring.
        """
        if _global_merge_accept is not None:
            raise ValueError(
                "BitBirch.set_merge() can only be called if "
                "the global set_merge() function has *not* been used"
            )
        _tolerance = 0.05 if tolerance is None else tolerance
        if isinstance(criterion, MergeAcceptFunction):
            self._merge_accept_fn = criterion
        elif isinstance(criterion, str):
            self._merge_accept_fn = get_merge_accept_fn(criterion, _tolerance)
        if hasattr(self._merge_accept_fn, "tolerance"):
            self._merge_accept_fn.tolerance = _tolerance
        elif tolerance is not None:
            raise ValueError(f"Can't set tolerance for {self._merge_accept_fn}")
        if threshold is not None:
            self.threshold = threshold
        if branching_factor is not None:
            self.branching_factor = branching_factor

    def fit(
        self,
        X: _Input | Path | str,
        /,
        reinsert_indices: tp.Iterable[int] | None = None,
        input_is_packed: bool = True,
        n_features: int | None = None,
        max_fps: int | None = None,
    ) -> tpx.Self:
        r"""Build a BF Tree for the input data.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Input data.

        reinsert_indices : Iterable[int]
            If ``reinsert_indices`` is passed, ``X`` corresponds only to the molecules
            that will be reinserted into the tree, and ``reinsert_indices`` are the
            indices associated with these molecules.

        input_is_packed : bool
            Whether the input fingerprints are packed

        n_features : int
            Number of features of input fingerprints. Only required for packed inputs
            if it is not a multiple of 8, otherwise it is redundant.

        Returns
        -------
        self
            Fitted estimator.
        """
        if isinstance(X, (Path, str)):
            X = _mmap_file_and_madvise_sequential(Path(X), max_fps=max_fps)
            mmanager = _ArrayMemPagesManager.from_bb_input(X)
        else:
            X = X[:max_fps]
            mmanager = _ArrayMemPagesManager.from_bb_input(X, can_release=False)

        n_features = _validate_n_features(X, input_is_packed, n_features)
        # Start a new tree the first time this function is called
        if self._only_has_leaves:
            raise ValueError("Internal nodes were released, call reset() before fit()")
        if not self.is_init:
            self._initialize_tree(n_features)
        self._root = cast("_BFNode", self._root)  # After init, this is not None

        # The array iterator either copies, un-sparsifies, or does nothing
        # with the array rows, depending on the kind of X passed
        arr_iterable = _get_array_iterable(X, input_is_packed, n_features)
        arr_iterable = cast(tp.Iterable[NDArray[np.uint8]], arr_iterable)
        iterable: tp.Iterable[tuple[int, NDArray[np.uint8]]]
        if reinsert_indices is None:
            iterable = enumerate(arr_iterable, self.num_fitted_fps)
        else:
            iterable = zip(reinsert_indices, arr_iterable)

        threshold = self.threshold
        branching_factor = self.branching_factor
        merge_accept_fn = self._merge_accept_fn

        arr_idx = 0
        for idx, fp in iterable:
            subcluster = _BFSubcluster(
                linear_sum=fp, mol_indices=[idx], n_features=n_features
            )
            split = self._root.insert_bf_subcluster(
                subcluster, merge_accept_fn, threshold
            )

            if split:
                new_subcluster1, new_subcluster2 = _split_node(self._root)
                self._root = _BFNode(branching_factor, n_features)
                self._root.append_subcluster(new_subcluster1)
                self._root.append_subcluster(new_subcluster2)

            self._num_fitted_fps += 1
            arr_idx += 1
            if mmanager.can_release and mmanager.should_release_curr_page(arr_idx):
                mmanager.release_curr_page_and_update_addr()
        return self
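
    # Illustrative fit() sketch (hedged; 'my_fps.npy' is a placeholder path): X may
    # be an in-memory packed uint8 array or a path, in which case the file is
    # memory-mapped and its pages are released as fitting advances:
    #
    #     bb_tree = BitBirch()
    #     bb_tree.fit("my_fps.npy", max_fps=1_000_000)
    #     bb_tree.num_fitted_fps  # -> number of fingerprints inserted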

    def _fit_buffers(
        self,
        X: _Input | Path | str,
        reinsert_index_seqs: (
            tp.Iterable[tp.Sequence[int]] | tp.Literal["omit"]
        ) = "omit",
    ) -> tpx.Self:
        r"""Build a BF Tree starting from buffers

        Buffers are arrays of the form:
        - buffer[0:-1] = linear_sum
        - buffer[-1] = n_samples
        And X is either an array or a list of such buffers

        If `reinsert_index_seqs` is passed, X corresponds only to the buffers to be
        reinserted into the tree, and `reinsert_index_seqs` are the sequences
        of indices associated with such buffers.

        If `reinsert_index_seqs` is "omit", then no indices are collected in the tree.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples + 1, n_features)
            Input data.

        Returns
        -------
        self
            Fitted estimator.
        """
        if isinstance(X, (Path, str)):
            X = _mmap_file_and_madvise_sequential(Path(X))
            mmanager = _ArrayMemPagesManager.from_bb_input(X)
        else:
            mmanager = _ArrayMemPagesManager.from_bb_input(X, can_release=False)

        n_features = _validate_n_features(X, input_is_packed=False) - 1
        # Start a new tree the first time this function is called
        if self._only_has_leaves:
            raise ValueError("Internal nodes were released, call reset() before fit()")
        if not self.is_init:
            self._initialize_tree(n_features)
        self._root = cast("_BFNode", self._root)  # After init, this is not None

        # The array iterator either copies, un-sparsifies, or does nothing with the
        # array rows, depending on the kind of X passed
        arr_iterable = _get_array_iterable(X, input_is_packed=False, dtype=X[0].dtype)
        merge_accept_fn = self._merge_accept_fn
        threshold = self.threshold
        branching_factor = self.branching_factor
        idx_provider: tp.Iterable[tp.Sequence[int]]
        arr_idx = 0
        if reinsert_index_seqs == "omit":
            idx_provider = (() for idx in range(self.num_fitted_fps))
            check = False
        else:
            idx_provider = reinsert_index_seqs
            check = True
        for idxs, buf in zip(idx_provider, arr_iterable):
            subcluster = _BFSubcluster(
                buffer=buf, mol_indices=idxs, n_features=n_features, check_indices=check
            )
            split = self._root.insert_bf_subcluster(
                subcluster, merge_accept_fn, threshold
            )

            if split:
                new_subcluster1, new_subcluster2 = _split_node(self._root)
                self._root = _BFNode(branching_factor, n_features)
                self._root.append_subcluster(new_subcluster1)
                self._root.append_subcluster(new_subcluster2)

            self._num_fitted_fps += len(idxs)
            arr_idx += 1
            if mmanager.can_release and mmanager.should_release_curr_page(arr_idx):
                mmanager.release_curr_page_and_update_addr()
        return self
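
    # Buffer layout sketch for _fit_buffers (illustrative values): a cluster of 3
    # molecules over 4 features whose linear sum is [2, 3, 0, 1] is encoded as a
    # single flat row, with the sample count in the last slot:
    #
    #     buf = np.array([2, 3, 0, 1, 3], dtype=np.uint8)  # [:-1]=sum, [-1]=n
    #     bb_tree._fit_buffers([buf], reinsert_index_seqs=[[10, 11, 12]])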

    # Provided for backwards compatibility
    def fit_reinsert(
        self,
        X: _Input | Path | str,
        reinsert_indices: tp.Iterable[int],
        input_is_packed: bool = True,
        n_features: int | None = None,
        max_fps: int | None = None,
    ) -> tpx.Self:
        r""":meta private:"""
        return self.fit(X, reinsert_indices, input_is_packed, n_features, max_fps)

    def _initialize_tree(self, n_features: int) -> None:
        # Initialize the root (and a dummy node to get back the subclusters)
        self._root = _BFNode(self.branching_factor, n_features)
        self._dummy_leaf._next_leaf = self._root
        self._root._prev_leaf = self._dummy_leaf

    def _get_leaves(self) -> tp.Iterator[_BFNode]:
        r"""Yields all leaf nodes"""
        if not self.is_init:
            raise ValueError("The model has not been fitted yet.")
        leaf = self._dummy_leaf._next_leaf
        while leaf is not None:
            yield leaf
            leaf = leaf._next_leaf

    def get_centroids_mol_ids(
        self, sort: bool = True, packed: bool = True
    ) -> _CentroidsMolIds:
        """Get a dict with centroids and mol indices of the leaves"""
        # NOTE: This is different from the original bitbirch, here outputs are sorted
        # by default
        centroids = []
        mol_ids = []
        attr = "packed_centroid" if packed else "unpacked_centroid"
        for subcluster in self._get_leaf_bfs(sort=sort):
            centroids.append(getattr(subcluster, attr))
            mol_ids.append(subcluster.mol_indices)
        return {"centroids": centroids, "mol_ids": mol_ids}

    def get_centroids(
        self,
        sort: bool = True,
        packed: bool = True,
    ) -> list[NDArray[np.uint8]]:
        r"""Get a list of arrays with the centroids' fingerprints"""
        # NOTE: This is different from the original bitbirch, here outputs are sorted
        # by default
        attr = "packed_centroid" if packed else "unpacked_centroid"
        return [getattr(s, attr) for s in self._get_leaf_bfs(sort=sort)]

    def get_medoids_mol_ids(
        self,
        fps: NDArray[np.uint8],
        sort: bool = True,
        pack: bool = True,
        global_clusters: bool = False,
        input_is_packed: bool = True,
        n_features: int | None = None,
    ) -> _MedoidsMolIds:
        """Get a dict with medoids and mol indices of the leaves"""
        cluster_members = self.get_cluster_mol_ids(
            sort=sort, global_clusters=global_clusters
        )

        if input_is_packed:
            fps = _unpack_fingerprints(fps, n_features=n_features)
        cluster_medoids = self._unpacked_medoids_from_members(fps, cluster_members)
        if pack:
            cluster_medoids = pack_fingerprints(cluster_medoids)
        return {"medoids": cluster_medoids, "mol_ids": cluster_members}

    @staticmethod
    def _unpacked_medoids_from_members(
        unpacked_fps: NDArray[np.uint8], cluster_members: tp.Sequence[list[int]]
    ) -> NDArray[np.uint8]:
        cluster_medoids = np.zeros(
            (len(cluster_members), unpacked_fps.shape[1]), dtype=np.uint8
        )
        for idx, members in enumerate(cluster_members):
            cluster_medoids[idx, :] = jt_isim_medoid(
                unpacked_fps[members],
                input_is_packed=False,
                pack=False,
            )[1]
        return cluster_medoids

    def get_medoids(
        self,
        fps: NDArray[np.uint8],
        sort: bool = True,
        pack: bool = True,
        global_clusters: bool = False,
        input_is_packed: bool = True,
        n_features: int | None = None,
    ) -> NDArray[np.uint8]:
        return self.get_medoids_mol_ids(
            fps, sort, pack, global_clusters, input_is_packed, n_features
        )["medoids"]

    def get_cluster_mol_ids(
        self, sort: bool = True, global_clusters: bool = False
    ) -> list[list[int]]:
        r"""Get the indices of the molecules in each cluster"""
        if global_clusters:
            if self._global_clustering_centroid_labels is None:
                raise ValueError(
                    "Must perform global clustering before fetching global labels"
                )
            bf_labels = (
                self._global_clustering_centroid_labels - 1
            )  # sub 1 to use as idxs

            # Collect the members of all clusters
            it = (bf.mol_indices for bf in self._get_leaf_bfs(sort=sort))
            return self._new_ids_from_labels(it, bf_labels, self._n_global_clusters)

        return [s.mol_indices for s in self._get_leaf_bfs(sort=sort)]

    @staticmethod
    def _new_ids_from_labels(
        members: tp.Iterable[list[int]],
        labels: NDArray[np.int64],
        n_labels: int | None = None,
    ) -> list[list[int]]:
        r"""Get the indices of the molecules in each cluster"""
        if n_labels is None:
            n_labels = len(np.unique(labels))
        new_members: list[list[int]] = [[] for _ in range(n_labels)]
        for i, idxs in enumerate(members):
            new_members[labels[i]].extend(idxs)
        return new_members

    def get_assignments(
        self,
        n_mols: int | None = None,
        sort: bool = True,
        check_valid: bool = True,
        global_clusters: bool = False,
    ) -> NDArray[np.uint64]:
        r"""Get an array with the cluster labels associated with each fingerprint idx"""
        if n_mols is not None:
            warnings.warn("The n_mols argument is redundant", DeprecationWarning)
        if n_mols is not None and n_mols != self.num_fitted_fps:
            raise ValueError(
                f"Provided n_mols {n_mols} is different"
                f" from the number of fitted fingerprints {self.num_fitted_fps}"
            )
        if check_valid:
            assignments = np.full(self.num_fitted_fps, 0, dtype=np.uint64)
        else:
            assignments = np.empty(self.num_fitted_fps, dtype=np.uint64)

        iterator: tp.Iterable[list[int]]
        if sort:
            iterator = self.get_cluster_mol_ids(sort=True)
        else:
            iterator = (
                s.mol_indices for leaf in self._get_leaves() for s in leaf._subclusters
            )

        if global_clusters:
            if self._global_clustering_centroid_labels is None:
                raise ValueError(
                    "Must perform global clustering before fetching global labels"
                )
            # Assign according to global clustering labels
            final_labels = self._global_clustering_centroid_labels
            for mol_ids, label in zip(iterator, final_labels):
                assignments[mol_ids] = label
        else:
            # Assign according to mol_ids from the subclusters
            for i, mol_ids in enumerate(iterator, 1):
                assignments[mol_ids] = i

        # Check that there are no unassigned molecules
        if check_valid and (assignments == 0).any():
            raise ValueError("There are unassigned molecules")
        return assignments
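
    # Label convention sketch (illustrative): cluster labels start at 1, and 0 is
    # reserved as the "unassigned" sentinel that check_valid tests for:
    #
    #     labels = bb_tree.get_assignments()
    #     labels[0]  # e.g. 1 -> fingerprint 0 belongs to the largest cluster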

    def dump_assignments(
        self,
        path: Path | str,
        smiles: tp.Iterable[str] = (),
        sort: bool = True,
        global_clusters: bool = False,
        check_valid: bool = True,
    ) -> None:
        r"""Dump the cluster assignments to a ``*.csv`` file"""
        import pandas as pd  # Hide pandas import since it is heavy

        path = Path(path)
        if isinstance(smiles, str):
            smiles = [smiles]
        smiles = np.asarray(smiles, dtype=np.str_)
        # Dump cluster assignments to *.csv
        assignments = self.get_assignments(
            sort=sort, check_valid=check_valid, global_clusters=global_clusters
        )
        if smiles.size and (len(assignments) != len(smiles)):
            raise ValueError(
                f"Length of the provided smiles {len(smiles)}"
                f" must match the number of fitted fingerprints {self.num_fitted_fps}"
            )
        df = pd.DataFrame({"assignments": assignments})
        if smiles.size:
            df["smiles"] = smiles
        df.to_csv(path, index=False)
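
    # Resulting file sketch (illustrative rows; 'clusters.csv' is a placeholder
    # name): one row per fitted fingerprint, with an optional smiles column:
    #
    #     bb_tree.dump_assignments("clusters.csv", smiles=smiles_list)
    #     # clusters.csv -> assignments,smiles
    #     #                 1,CCO
    #     #                 1,CCN
    #     #                 2,c1ccccc1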

    def reset(self) -> None:
        r"""Reset the tree state

        Delete *all internal nodes and leaves*, does not reset the merge criterion or
        other merge parameters.
        """
        # Reset the whole tree
        if self._root is not None:
            self._root._prev_leaf = None
            self._root._next_leaf = None
        self._dummy_leaf._next_leaf = None
        self._root = None
        self._num_fitted_fps = 0

    def delete_internal_nodes(self) -> None:
        r"""Delete all nodes in the tree that are not leaves

        This function is for advanced usage only. It should be called if there is need
        to use the BitBirch leaf clusters, but you need to release the memory held by
        the internal nodes. After calling this function, no more fingerprints can be
        fit into the tree, unless a call to `BitBirch.reset` afterwards releases the
        *whole tree*, including the leaf clusters.
        """
        if not tp.cast(_BFNode, self._root).is_leaf:
            # release all nodes that are not leaves,
            # they are kept alive by references from dummy_leaf
            self._root = None

    @property
    def _only_has_leaves(self) -> bool:
        return (self._root is None) and (self._dummy_leaf._next_leaf is not None)

    def recluster_inplace(
        self,
        iterations: int = 1,
        extra_threshold: float = 0.0,
        shuffle: bool = False,
        seed: int | None = None,
        verbose: bool = False,
        stop_early: bool = False,
    ) -> tpx.Self:
        r"""Refine singleton clusters by re-inserting them into the tree

        Parameters
        ----------
        extra_threshold : float, default=0.0
            The amount to increase the current threshold in each iteration.

        iterations : int, default=1
            The maximum number of refinement iterations to perform.

        Returns
        -------
        self : BitBirch
            The fitted estimator with refined clusters.

        Raises
        ------
        ValueError
            If the model has not been fitted.
        """
        if not self.is_init:
            raise ValueError("The model has not been fitted yet.")

        singletons_before = 0
        for _ in range(iterations):
            # Get the BFs
            bfs = self._get_leaf_bfs(sort=True)

            # Count the number of clusters and singletons
            singleton_bfs = sum(1 for bf in bfs if bf.n_samples == 1)

            # Check stopping criteria
            if stop_early:
                if singleton_bfs == 0 or singleton_bfs == singletons_before:
                    # No more singletons to refine
                    break
                singletons_before = singleton_bfs

            # Print progress
            if verbose:
                print(f"Current number of clusters: {len(bfs)}")
                print(f"Current number of singletons: {singleton_bfs}")

            if shuffle:
                random.seed(seed)
                random.shuffle(bfs)

            # Prepare the buffers for refitting
            fps_bfs, mols_bfs = self._prepare_bf_to_buffer_dicts(bfs)

            # Reset the tree
            self.reset()

            # Change the threshold
            self.threshold += extra_threshold

            # Refit the subclusters
            for bufs, mol_idxs in zip(fps_bfs.values(), mols_bfs.values()):
                self._fit_buffers(bufs, reinsert_index_seqs=mol_idxs)

        # Print final stats
        if verbose:
            bfs = self._get_leaf_bfs(sort=True)
            singleton_bfs = sum(1 for bf in bfs if bf.n_samples == 1)
            print(f"Final number of clusters: {len(bfs)}")
            print(f"Final number of singletons: {singleton_bfs}")
        return self
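
    # Refinement loop sketch (illustrative parameters): each pass rebuilds the tree
    # from the previous leaf clusters, optionally loosening the threshold, until no
    # singletons remain or progress stalls:
    #
    #     bb_tree.recluster_inplace(iterations=3, extra_threshold=0.01,
    #                               stop_early=True, verbose=True)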
|
|
1184
|
+
|
|
1185
|
+
def refine_inplace(
|
|
1186
|
+
self,
|
|
1187
|
+
X: _Input | Path | str | tp.Sequence[Path],
|
|
1188
|
+
initial_mol: int = 0,
|
|
1189
|
+
input_is_packed: bool = True,
|
|
1190
|
+
n_largest: int = 1,
|
|
1191
|
+
) -> tpx.Self:
|
|
1192
|
+
r"""Refine the tree: break the largest clusters in singletons and re-fit"""
|
|
1193
|
+
if not self.is_init:
|
|
1194
|
+
raise ValueError("The model has not been fitted yet.")
|
|
1195
|
+
# Release the memory held by non-leaf nodes
|
|
1196
|
+
self.delete_internal_nodes()
|
|
1197
|
+
|
|
1198
|
+
# Extract the BitFeatures of the leaves, breaking the largest cluster apart into
|
|
1199
|
+
# singleton subclusters
|
|
1200
|
+
fps_bfs, mols_bfs = self._bf_to_np_refine( # This function takes a bunch of mem
|
|
1201
|
+
X,
|
|
1202
|
+
initial_mol=initial_mol,
|
|
1203
|
+
input_is_packed=input_is_packed,
|
|
1204
|
+
n_largest=n_largest,
|
|
1205
|
+
)
|
|
1206
|
+
# Reset the tree
|
|
1207
|
+
self.reset()
|
|
1208
|
+
|
|
1209
|
+
# Rebuild the tree again from scratch, reinserting all the subclusters
|
|
1210
|
+
for bufs, mol_idxs in zip(fps_bfs.values(), mols_bfs.values()):
|
|
1211
|
+
self._fit_buffers(bufs, reinsert_index_seqs=mol_idxs)
|
|
1212
|
+
return self
|
|
1213
|
+
|
|
1214
|
+
def _get_leaf_bfs(self, sort: bool = True) -> list[_BFSubcluster]:
|
|
1215
|
+
r"""Get the BitFeatures of the leaves"""
|
|
1216
|
+
bfs = [s for leaf in self._get_leaves() for s in leaf._subclusters]
|
|
1217
|
+
if sort:
|
|
1218
|
+
# Sort the BitFeatures by the number of samples in the cluster
|
|
1219
|
+
bfs.sort(key=lambda x: x.n_samples, reverse=True)
|
|
1220
|
+
return bfs
|
|
1221
|
+
|
|
+    def _bf_to_np_refine(
+        self,
+        X: _Input | Path | str | tp.Sequence[Path],
+        initial_mol: int = 0,
+        input_is_packed: bool = True,
+        n_largest: int = 1,
+    ) -> tuple[dict[str, list[NDArray[np.integer]]], dict[str, list[list[int]]]]:
+        """Prepare numpy buffers ('np') for BitFeatures, splitting the biggest n clusters
+
+        The largest clusters are split into singletons. In order to perform this split,
+        the *original* fingerprint array used to fit the tree (X) has to be provided,
+        together with the index associated with the first fingerprint.
+
+        The split is only performed for the returned 'np' buffers; the clusters in the
+        tree itself are not modified.
+        """
+        if n_largest == 0:
+            return self._bf_to_np()
+
+        if n_largest < 1:
+            raise ValueError("n_largest must be >= 0")
+
+        bfs = self._get_leaf_bfs()
+        largest = bfs[:n_largest]
+        rest = bfs[n_largest:]
+        n_features = largest[0].n_features
+
+        dtypes_to_fp, dtypes_to_mols = self._prepare_bf_to_buffer_dicts(rest)
+        # Add X and mol indices of the "big" clusters
+        if input_is_packed:
+            unpack_or_copy = lambda x: _unpack_fingerprints(
+                cast(NDArray[np.uint8], x), n_features
+            )
+        else:
+            unpack_or_copy = lambda x: x.copy()
+
+        for big_bf in largest:
+            full_arr_idxs = [(idx - initial_mol) for idx in big_bf.mol_indices]
+            _X: _Input
+            if isinstance(X, (Path, str)):
+                # Only load the specific required mol idxs
+                _X = cast(NDArray[np.integer], np.load(X, mmap_mode="r"))[full_arr_idxs]
+                arr_idxs = list(range(len(_X)))
+                mol_idxs = big_bf.mol_indices
+            elif isinstance(X[0], Path):
+                # Only load the specific required mol idxs (the file readers expect
+                # sorted indices, so sort them and permute the mol indices to match)
+                sort_idxs = np.argsort(full_arr_idxs)
+                _X = _get_fingerprints_from_file_seq(
+                    cast(tp.Sequence[Path], X),
+                    [full_arr_idxs[i] for i in sort_idxs],
+                )
+                arr_idxs = list(range(len(_X)))
+                mol_idxs = big_bf.mol_indices
+                mol_idxs = [mol_idxs[i] for i in sort_idxs]
+            else:
+                # Index the full array / list
+                _X = cast(_Input, X)
+                arr_idxs = full_arr_idxs
+                mol_idxs = big_bf.mol_indices
+
+            for mol_idx, arr_idx in zip(mol_idxs, arr_idxs):
+                buffer = np.empty(n_features + 1, dtype=np.uint8)
+                buffer[:-1] = unpack_or_copy(_X[arr_idx])
+                buffer[-1] = 1
+                dtypes_to_fp["uint8"].append(buffer)
+                dtypes_to_mols["uint8"].append([mol_idx])
+        return dtypes_to_fp, dtypes_to_mols
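
The `elif` branch above reads fingerprints from a sequence of files, which requires the requested indices in sorted order; `np.argsort` both sorts the requests and permutes the molecule indices so the pairing survives. A self-contained illustration of that bookkeeping, with a plain dict standing in for the file readers:

    import numpy as np

    requested = [9, 3, 7]                           # molecule indices, original order
    sort_idxs = np.argsort(requested)               # array([1, 2, 0])
    sorted_req = [requested[i] for i in sort_idxs]  # [3, 7, 9], read in this order
    storage = {3: "fp3", 7: "fp7", 9: "fp9"}        # stand-in for the files on disk
    rows = [storage[i] for i in sorted_req]         # what the readers would return
    print(list(zip(sorted_req, rows)))              # [(3, 'fp3'), (7, 'fp7'), (9, 'fp9')]
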
+
+    def _bf_to_np(
+        self,
+    ) -> tuple[dict[str, list[NDArray[np.integer]]], dict[str, list[list[int]]]]:
+        """Prepare numpy buffers ('np') for BitFeatures of all clusters"""
+        return self._prepare_bf_to_buffer_dicts(self._get_leaf_bfs())
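
Throughout these helpers, "packed" fingerprints are bit-packed uint8 arrays (8 fingerprint bits per byte). A self-contained sketch of the round trip using plain numpy; bblean's `_unpack_fingerprints` is assumed to behave like `np.unpackbits` with an explicit feature count:

    import numpy as np

    n_features = 12
    fp = np.random.randint(0, 2, size=n_features).astype(np.uint8)  # dense 0/1 bits
    packed = np.packbits(fp)                             # 12 bits -> 2 bytes
    unpacked = np.unpackbits(packed, count=n_features)   # drop the padding bits
    assert (unpacked == fp).all()
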
+
+    @staticmethod
+    def _prepare_bf_to_buffer_dicts(
+        bfs: list["_BFSubcluster"],
+    ) -> tuple[dict[str, list[NDArray[np.integer]]], dict[str, list[list[int]]]]:
+        # Helper function used when returning lists of subclusters
+        dtypes_to_fp = defaultdict(list)
+        dtypes_to_mols = defaultdict(list)
+        for bf in bfs:
+            dtypes_to_fp[bf.dtype_name].append(bf._buffer)
+            dtypes_to_mols[bf.dtype_name].append(bf.mol_indices)
+        return dtypes_to_fp, dtypes_to_mols
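
The helper above buckets subcluster buffers by dtype name with `collections.defaultdict`. A runnable toy version of the same grouping idiom, with bare numpy arrays standing in for `_BFSubcluster` buffers:

    from collections import defaultdict
    import numpy as np

    buffers = [np.zeros(4, np.uint8), np.zeros(4, np.int64), np.ones(4, np.uint8)]
    by_dtype = defaultdict(list)
    for buf in buffers:
        by_dtype[buf.dtype.name].append(buf)  # keys are created on first use
    print(sorted(by_dtype))                   # ['int64', 'uint8']
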
+
+    def __repr__(self) -> str:
+        fn = self._merge_accept_fn
+        parts = [
+            f"threshold={self.threshold}",
+            f"branching_factor={self.branching_factor}",
+            f"merge_criterion='{fn.name if fn.name in BUILTIN_MERGES else fn}'",
+        ]
+        if self.tolerance is not None:
+            parts.append(f"tolerance={self.tolerance}")
+        return f"{self.__class__.__name__}({', '.join(parts)})"
+
+    def global_clustering(
+        self,
+        n_clusters: int,
+        *,
+        method: str = "kmeans",
+        # TODO: Type correctly
+        **method_kwargs: tp.Any,
+    ) -> tpx.Self:
+        r""":meta private:"""
+        warnings.warn(
+            "Global clustering is an experimental feature;"
+            " it may be modified without warning, please do not use it"
+        )
+        if not self.is_init:
+            raise ValueError("The model has not been fitted yet.")
+        centroids = np.vstack(self.get_centroids(packed=False))
+        labels = self._centrals_global_clustering(
+            centroids, n_clusters, method=method, input_is_packed=False, **method_kwargs
+        )
+        num_centroids = len(centroids)
+        self._n_global_clusters = (
+            n_clusters if num_centroids > n_clusters else num_centroids
+        )
+        self._global_clustering_centroid_labels = labels
+        return self
+
+    @staticmethod
+    def _centrals_global_clustering(
+        centrals: NDArray[np.uint8],
+        n_clusters: int,
+        *,
+        method: str = "kmeans",
+        input_is_packed: bool = True,
+        n_features: int | None = None,
+        # TODO: Type correctly
+        **method_kwargs: tp.Any,
+    ) -> NDArray[np.int64]:
+        r""":meta private:"""
+        if method not in {"agglomerative", "kmeans", "kmeans-normalized"}:
+            raise ValueError(f"Unknown method {method}")
+        # Returns the labels associated with global clustering
+        # Lazy import because sklearn is very heavy
+        from sklearn.cluster import KMeans, AgglomerativeClustering
+        from sklearn.exceptions import ConvergenceWarning
+
+        if input_is_packed:
+            centrals = _unpack_fingerprints(centrals, n_features)
+
+        num_centrals = len(centrals)
+        if num_centrals < n_clusters:
+            msg = (
+                f"Number of subclusters found ({num_centrals}) by BitBIRCH is less "
+                f"than n_clusters ({n_clusters}). Decrease n_clusters or the threshold."
+            )
+            warnings.warn(msg, ConvergenceWarning, stacklevel=2)
+            n_clusters = num_centrals
+
+        if method == "kmeans-normalized":
+            centrals = centrals / np.linalg.norm(centrals, axis=1, keepdims=True)
+        if method in ["kmeans", "kmeans-normalized"]:
+            predictor = KMeans(n_clusters=n_clusters, **method_kwargs)
+        elif method == "agglomerative":
+            predictor = AgglomerativeClustering(n_clusters=n_clusters, **method_kwargs)
+        else:
+            raise ValueError(
+                "method must be one of 'kmeans', 'kmeans-normalized' or 'agglomerative'"
+            )
+
+        # Add 1 to start labels from 1 instead of 0, so 0 can be used as a sentinel
+        # value.
+        # This is the bottleneck for building this index:
+        # k-means is feasible, agglomerative is extremely expensive
+        return predictor.fit_predict(centrals) + 1
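
A self-contained sketch of the centroid-level clustering the static method above performs, using random stand-in centroids rather than tree output. Note the `+ 1` shift, which keeps label 0 free as a sentinel:

    import numpy as np
    from sklearn.cluster import KMeans

    rng = np.random.default_rng(0)
    centroids = rng.integers(0, 2, size=(100, 64)).astype(np.float64)  # fake unpacked centroids
    labels = KMeans(n_clusters=5, n_init=10).fit_predict(centroids) + 1
    print(labels.min())  # 1 -- label 0 is never produced
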
+
+
+# There are 4 cases here:
+# (1) The input is a scipy.sparse array
+# (2) The input is a list of dense arrays (nothing required)
+# (3) The input is a packed array or list of packed arrays (unpack required)
+# (4) The input is a dense array (copy required)
+# NOTE: Sparse iteration hack is taken from sklearn.
+# It yields a densified row when iterating over a sparse matrix, instead
+# of constructing a sparse matrix for every row, which is expensive.
+#
+# Output is *always* of dtype uint8, but input (if unpacked) can be of arbitrary dtype
+# It is most efficient for input to be uint8 to prevent copies
+def _get_array_iterable(
+    X: _Input,
+    input_is_packed: bool = True,
+    n_features: int | None = None,
+    dtype: DTypeLike = np.uint8,
+) -> tp.Iterable[NDArray[np.integer]]:
+    if input_is_packed:
+        # Unpacking copies the fingerprints, so no extra copy is required
+        # NOTE: cast seems to have a very small overhead in this loop for some reason
+        return (_unpack_fingerprints(a, n_features) for a in X)  # type: ignore
+    if isinstance(X, list):
+        # No copy is required here unless the dtype is not uint8
+        return (a.astype(dtype, copy=False) for a in X)
+    if isinstance(X, np.ndarray):
+        # A copy is required here to avoid keeping a ref to the full array alive
+        return (a.astype(dtype, copy=True) for a in X)
+    return _iter_sparse(X)
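
The list and ndarray branches above differ only in `copy=`. With `copy=False`, `astype` returns the array itself when no conversion is needed; with `copy=True`, it always allocates a fresh buffer, which here breaks the row view's reference to the full parent array. A runnable check of both behaviors:

    import numpy as np

    x = np.arange(4, dtype=np.uint8)
    print(x.astype(np.uint8, copy=False) is x)   # True: no conversion, no copy
    print(x.astype(np.uint8, copy=True) is x)    # False: always a fresh buffer

    row = np.zeros((1000, 1000), np.uint8)[0]    # a view pins the whole parent array
    print(row.base is not None)                  # True
    print(row.astype(np.uint8, copy=True).base)  # None: parent can now be freed
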
+
+
+# NOTE: In practice this branch is never used; it could probably safely be deleted
+def _iter_sparse(X: tp.Any) -> tp.Iterator[NDArray[np.uint8]]:
+    import scipy.sparse  # Hide this import since scipy is heavy
+
+    if not scipy.sparse.issparse(X):
+        raise ValueError(f"Input of type {type(X)} is not supported")
+    n_samples, n_features = X.shape
+    X_indices = X.indices
+    X_data = X.data
+    X_indptr = X.indptr
+    for i in range(n_samples):
+        a = np.zeros(n_features, dtype=np.uint8)
+        startptr, endptr = X_indptr[i], X_indptr[i + 1]
+        nonzero_indices = X_indices[startptr:endptr]
+        a[nonzero_indices] = X_data[startptr:endptr].astype(np.uint8, copy=False)
+        yield a
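
A self-contained check of the CSR row-densification trick above, comparing it against scipy's own `.toarray()` on a tiny matrix:

    import numpy as np
    import scipy.sparse

    X = scipy.sparse.csr_matrix(np.array([[0, 1, 0], [1, 0, 1]], dtype=np.uint8))
    for i in range(X.shape[0]):
        start, end = X.indptr[i], X.indptr[i + 1]
        row = np.zeros(X.shape[1], dtype=np.uint8)
        row[X.indices[start:end]] = X.data[start:end]  # scatter nonzeros into a dense row
        assert (row == X[i].toarray().ravel()).all()
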