pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
- pytme-0.2.9.data/scripts/match_template.py +1135 -0
- pytme-0.2.9.data/scripts/postprocess.py +622 -0
- pytme-0.2.9.data/scripts/preprocess.py +209 -0
- pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
- pytme-0.2.9.dist-info/METADATA +95 -0
- pytme-0.2.9.dist-info/RECORD +119 -0
- pytme-0.2.9.dist-info/WHEEL +5 -0
- pytme-0.2.9.dist-info/entry_points.txt +6 -0
- pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
- pytme-0.2.9.dist-info/top_level.txt +3 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +1135 -0
- scripts/postprocess.py +622 -0
- scripts/preprocess.py +209 -0
- scripts/preprocessor_gui.py +1227 -0
- tests/__init__.py +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +216 -0
- tests/test_backends.py +446 -0
- tests/test_density.py +503 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +124 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +226 -0
- tests/test_matching_utils.py +189 -0
- tests/test_orientations.py +175 -0
- tests/test_parser.py +33 -0
- tests/test_rotations.py +153 -0
- tests/test_structure.py +247 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +186 -0
- tme/analyzer/aggregation.py +577 -0
- tme/analyzer/peaks.py +953 -0
- tme/backends/__init__.py +171 -0
- tme/backends/_cupy_utils.py +734 -0
- tme/backends/_jax_utils.py +188 -0
- tme/backends/cupy_backend.py +294 -0
- tme/backends/jax_backend.py +314 -0
- tme/backends/matching_backend.py +1270 -0
- tme/backends/mlx_backend.py +241 -0
- tme/backends/npfftw_backend.py +583 -0
- tme/backends/pytorch_backend.py +430 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2263 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/filters/__init__.py +6 -0
- tme/filters/_utils.py +311 -0
- tme/filters/bandpass.py +230 -0
- tme/filters/compose.py +81 -0
- tme/filters/ctf.py +393 -0
- tme/filters/reconstruction.py +160 -0
- tme/filters/wedge.py +542 -0
- tme/filters/whitening.py +191 -0
- tme/matching_data.py +863 -0
- tme/matching_exhaustive.py +497 -0
- tme/matching_optimization.py +1311 -0
- tme/matching_scores.py +1183 -0
- tme/matching_utils.py +1188 -0
- tme/memory.py +337 -0
- tme/orientations.py +598 -0
- tme/parser.py +685 -0
- tme/preprocessor.py +1329 -0
- tme/rotations.py +350 -0
- tme/structure.py +1864 -0
- tme/types.py +13 -0
tme/analyzer/peaks.py
ADDED
@@ -0,0 +1,953 @@
|
|
1
|
+
""" Implements classes to analyze outputs from exhaustive template matching.
|
2
|
+
|
3
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
from functools import wraps
|
9
|
+
from abc import ABC, abstractmethod
|
10
|
+
from typing import Tuple, List, Dict, Generator
|
11
|
+
|
12
|
+
import numpy as np
|
13
|
+
from skimage.feature import peak_local_max
|
14
|
+
from skimage.registration._phase_cross_correlation import _upsampled_dft
|
15
|
+
|
16
|
+
from ._utils import score_to_cart
|
17
|
+
from ..backends import backend as be
|
18
|
+
from ..matching_utils import split_shape
|
19
|
+
from ..types import BackendArray, NDArray
|
20
|
+
from ..rotations import euler_to_rotationmatrix
|
21
|
+
|
22
|
+
__all__ = [
|
23
|
+
"PeakCaller",
|
24
|
+
"PeakCallerSort",
|
25
|
+
"PeakCallerMaximumFilter",
|
26
|
+
"PeakCallerFast",
|
27
|
+
"PeakCallerRecursiveMasking",
|
28
|
+
"PeakCallerScipy",
|
29
|
+
"PeakClustering",
|
30
|
+
"filter_points",
|
31
|
+
"filter_points_indices",
|
32
|
+
]
|
33
|
+
|
34
|
+
# Return type of ``call_peaks`` implementations: (peak coordinates, peak details).
PeakType = Tuple[BackendArray, BackendArray]
|
35
|
+
|
36
|
+
|
37
|
+
def _filter_bucket(
    coordinates: BackendArray, min_distance: Tuple[float], scores: BackendArray = None
) -> BackendArray:
    """
    Keep a single point per grid bucket of edge length ``min_distance``.

    Parameters
    ----------
    coordinates : BackendArray
        Point coordinates with shape (n, d).
    min_distance : float or tuple of float
        Bucket edge length the shifted coordinates are divided by.
    scores : BackendArray, optional
        Score of each point (n,). If given, the highest scoring point of
        every bucket is retained, otherwise the index reported by
        ``be.unique(..., return_index=True)`` is used.

    Returns
    -------
    BackendArray
        Indices into ``coordinates`` of the retained points.
    """
    coordinates = be.subtract(coordinates, be.min(coordinates, axis=0))
    bucket_indices = be.astype(be.divide(coordinates, min_distance), int)

    # Flatten the d-dimensional bucket index into one integer using cumulative
    # strides (mixed radix). Raising each extent to the power of its axis
    # number is not injective and merges unrelated buckets.
    extents = be.max(bucket_indices, axis=0) + 1
    flattened_indices, stride = 0, 1
    for dim in range(bucket_indices.shape[1]):
        flattened_indices = flattened_indices + bucket_indices[:, dim] * stride
        stride = stride * int(extents[dim])

    if scores is not None:
        _, inverse_indices = be.unique(flattened_indices, return_inverse=True)

        # Normalise scores to [0, 1) so that subtracting them from the integer
        # bucket id never crosses into a neighbouring bucket id.
        scores = be.subtract(scores, be.min(scores))
        scores = be.divide(scores, be.max(scores) + 0.1, out=scores)
        scores = be.subtract(inverse_indices, scores)

        # Ascending sort groups points by bucket with the best score first;
        # the first element of every group is the one to keep.
        indices = be.argsort(scores)
        sorted_buckets = inverse_indices[indices]
        mask = sorted_buckets[1:] != sorted_buckets[:-1]
        mask = be.concatenate((be.full((1,), fill_value=1, dtype=mask.dtype), mask))
        return indices[mask]

    _, unique_indices = be.unique(flattened_indices, return_index=True)
    return unique_indices[be.argsort(unique_indices)]
|
64
|
+
|
65
|
+
|
66
|
+
def filter_points_indices(
    coordinates: BackendArray,
    min_distance: float,
    bucket_cutoff: int = 1e5,
    batch_dims: Tuple[int] = None,
    scores: BackendArray = None,
) -> BackendArray:
    """
    Determine indices of points that respect a minimum pairwise distance.

    Parameters
    ----------
    coordinates : BackendArray
        Point coordinates with shape (n, d).
    min_distance : float
        Minimum distance between retained points. Values smaller or equal
        to zero disable filtering.
    bucket_cutoff : int, optional
        Number of points from which on the bucket-based filter is used
        for numpy input, 1e5 by default.
    batch_dims : tuple of int, optional
        Coordinate axes that index batches rather than space.
    scores : BackendArray, optional
        Score of each point (n,). If given, higher scoring points take
        precedence over lower scoring ones.

    Returns
    -------
    BackendArray
        Indices into ``coordinates`` of the retained points. An empty tuple
        is returned if ``coordinates`` is empty.

    Notes
    -----
    Numpy input below ``bucket_cutoff`` is handled by the compiled
    ``find_candidate_indices`` extension, everything else by the
    bucket-based filter.
    """
    if min_distance <= 0:
        return be.arange(coordinates.shape[0])

    n_coords = coordinates.shape[0]
    if n_coords == 0:
        return ()

    if batch_dims is not None:
        # Stretch the batch axes by twice the minimum distance, presumably so
        # that points belonging to different batches never suppress each other.
        coordinates_new = be.zeros(coordinates.shape, coordinates.dtype)
        coordinates_new[:] = coordinates
        coordinates_new[..., batch_dims] = be.astype(
            coordinates[..., batch_dims] * (2 * min_distance), coordinates_new.dtype
        )
        coordinates = coordinates_new

    # ">=" closes the gap at n_coords == bucket_cutoff, which previously fell
    # through to a dense (n, n) distance matrix.
    if not isinstance(coordinates, np.ndarray) or n_coords >= bucket_cutoff:
        return _filter_bucket(coordinates, min_distance, scores)

    # Imported lazily so the trivial and bucket paths work without the
    # compiled extension.
    from ..extensions import find_candidate_indices

    if scores is None:
        # The extension result used to be computed and then discarded in
        # favour of a dense pairwise distance computation.
        indices = find_candidate_indices(coordinates, min_distance)
        return np.asarray(indices, dtype=int)

    sorted_indices = np.argsort(-scores)
    indices = find_candidate_indices(coordinates[sorted_indices], min_distance)
    return sorted_indices[indices]
|
105
|
+
|
106
|
+
|
107
|
+
def filter_points(
    coordinates: NDArray, min_distance: Tuple[int], batch_dims: Tuple[int] = None
) -> BackendArray:
    """
    Remove points that violate a minimum distance constraint.

    Parameters
    ----------
    coordinates : NDArray
        Point coordinates with shape (n, d).
    min_distance : float
        Minimum distance between retained points.
    batch_dims : tuple of int, optional
        Coordinate axes that index batches rather than space.

    Returns
    -------
    BackendArray
        Subset of ``coordinates`` selected by :py:meth:`filter_points_indices`.
    """
    # batch_dims has to be passed by keyword: the third positional parameter
    # of filter_points_indices is bucket_cutoff, not batch_dims.
    unique_indices = filter_points_indices(
        coordinates=coordinates, min_distance=min_distance, batch_dims=batch_dims
    )
    return coordinates[unique_indices]
|
113
|
+
|
114
|
+
|
115
|
+
def batchify(shape: Tuple[int], batch_dims: Tuple[int] = None) -> Generator:
    """
    Generate slices that iterate over the batch dimensions of an array.

    Parameters
    ----------
    shape : tuple of int
        Shape of the array to iterate over.
    batch_dims : tuple of int, optional
        Axes to iterate over one element at a time. Negative axes are
        supported. If None, a single item covering the full array is yielded.

    Yields
    ------
    tuple
        Tuple of slices selecting one batch element (length-one slices on
        batch axes, full slices elsewhere) and a tuple of offsets holding
        the batch index on batch axes and zero elsewhere.

    Raises
    ------
    IndexError
        If a batch dimension is out of range for ``shape``.
    """
    from itertools import product

    if batch_dims is None:
        yield (tuple(slice(None) for _ in shape), tuple(0 for _ in shape))
        return None

    ndim = len(shape)
    normalized_dims = []
    for dim in batch_dims:
        dim = int(dim)
        if not -ndim <= dim < ndim:
            raise IndexError(
                f"batch dimension {dim} is out of range for shape {tuple(shape)}."
            )
        # Map negative axes onto their non-negative counterpart
        normalized_dims.append(dim % ndim)

    batch_ranges = [range(shape[dim]) for dim in normalized_dims]
    for current_indices in product(*batch_ranges):
        # Pair every index with the axis it was generated for, so the result
        # is correct irrespective of the order of batch_dims.
        index_of = dict(zip(normalized_dims, current_indices))

        slice_list, offset_list = [], []
        for dim in range(ndim):
            if dim in index_of:
                index = index_of[dim]
                slice_list.append(slice(index, index + 1))
                offset_list.append(index)
            else:
                slice_list.append(slice(None))
                offset_list.append(0)
        yield (tuple(slice_list), tuple(offset_list))
|
142
|
+
|
143
|
+
|
144
|
+
class PeakCaller(ABC):
|
145
|
+
"""
|
146
|
+
Base class for peak calling algorithms.
|
147
|
+
|
148
|
+
Parameters
|
149
|
+
----------
|
150
|
+
shape : tuple of int
|
151
|
+
Score space shape. Used to determine dimension of peak calling problem.
|
152
|
+
num_peaks : int, optional
|
153
|
+
Number of candidate peaks to consider.
|
154
|
+
min_distance : int, optional
|
155
|
+
Minimum distance between peaks, 1 by default
|
156
|
+
min_boundary_distance : int, optional
|
157
|
+
Minimum distance to array boundaries, 0 by default.
|
158
|
+
min_score : float, optional
|
159
|
+
Minimum score from which to consider peaks.
|
160
|
+
max_score : float, optional
|
161
|
+
Maximum score upon which to consider peaks.
|
162
|
+
batch_dims : int, optional
|
163
|
+
Peak calling batch dimensions.
|
164
|
+
**kwargs
|
165
|
+
Optional keyword arguments.
|
166
|
+
|
167
|
+
Raises
|
168
|
+
------
|
169
|
+
ValueError
|
170
|
+
If num_peaks is less than or equal to zero.
|
171
|
+
If min_distances is less than zero.
|
172
|
+
"""
|
173
|
+
|
174
|
+
def __init__(
    self,
    shape: Tuple[int],
    num_peaks: int = 1000,
    min_distance: int = 1,
    min_boundary_distance: int = 0,
    min_score: float = None,
    max_score: float = None,
    batch_dims: Tuple[int] = None,
    shm_handler: object = None,
    **kwargs,
):
    """
    Validate the configuration and allocate the internal peak store.

    ``shm_handler`` is accepted but not used by this method. The optional
    keyword arguments ``fourier_shift``, ``convolution_mode``,
    ``targetshape`` and ``templateshape`` are stored for postprocessing.
    """
    if num_peaks <= 0:
        raise ValueError("num_peaks has to be larger than 0.")
    if min_distance < 0:
        raise ValueError("min_distance has to be non-negative.")
    if min_boundary_distance < 0:
        raise ValueError("min_boundary_distance has to be non-negative.")

    # Coerce before allocating: array shapes have to be integers, so values
    # such as 1e3 would otherwise pass validation and fail in be.full.
    num_peaks = int(num_peaks)
    if num_peaks <= 0:
        raise ValueError("num_peaks has to be larger than 0.")

    self.num_peaks = num_peaks
    self.min_distance = int(min_distance)
    self.min_boundary_distance = int(min_boundary_distance)

    # Placeholder entries: translation -1, zero rotation, zero score/details
    ndim = len(shape)
    self.translations = be.full(
        (self.num_peaks, ndim), fill_value=-1, dtype=be._int_dtype
    )
    self.rotations = be.full(
        (self.num_peaks, ndim, ndim), fill_value=0, dtype=be._float_dtype
    )
    self.scores = be.full((self.num_peaks,), fill_value=0, dtype=be._float_dtype)
    self.details = be.full((self.num_peaks,), fill_value=0, dtype=be._float_dtype)

    self.batch_dims = batch_dims
    if batch_dims is not None:
        self.batch_dims = tuple(int(x) for x in self.batch_dims)

    self.min_score, self.max_score = min_score, max_score

    # Postprocessing arguments
    self.fourier_shift = kwargs.get("fourier_shift", None)
    self.convolution_mode = kwargs.get("convolution_mode", None)
    self.targetshape = kwargs.get("targetshape", None)
    self.templateshape = kwargs.get("templateshape", None)
|
218
|
+
|
219
|
+
def __iter__(self) -> Generator:
    """
    Yield the stored translations, rotations, scores and details, in this
    order, after conversion with ``be.to_cpu_array``. The converted arrays
    are also kept in ``self.peak_list``.
    """
    stored = (self.translations, self.rotations, self.scores, self.details)
    self.peak_list = [be.to_cpu_array(array) for array in stored]
    yield from self.peak_list
|
231
|
+
|
232
|
+
def _get_peak_mask(self, peaks: BackendArray, scores: BackendArray) -> BackendArray:
    """
    Compute a boolean mask of peaks that pass all configured criteria.

    Parameters
    ----------
    peaks : BackendArray
        Peak coordinates with shape (n, d).
    scores : BackendArray
        Score array the peaks index into.

    Returns
    -------
    BackendArray or None
        Boolean mask of shape (n,) or None if no peak is valid.

    Notes
    -----
    Peaks are rejected if they are closer than ``min_boundary_distance``
    to the array boundary (batch axes excluded), have a NaN score or fall
    outside of [``min_score``, ``max_score``]. If more candidates than
    ``num_peaks`` remain, only the ``2 * num_peaks`` best are kept.
    """
    if not len(peaks):
        return None

    valid_peaks = be.full((peaks.shape[0],), fill_value=1) == 1
    if self.min_boundary_distance > 0:
        upper_limit = be.subtract(
            be.to_backend_array(scores.shape), self.min_boundary_distance
        )
        valid_peaks = be.multiply(
            peaks < upper_limit,
            peaks >= self.min_boundary_distance,
        )
        if self.batch_dims is not None:
            # Batch axes carry no spatial meaning, never reject based on them
            valid_peaks[..., self.batch_dims] = True

        valid_peaks = be.sum(valid_peaks, axis=1) == peaks.shape[1]

    # Score thresholds and nan removal (NaN != NaN)
    peak_scores = scores[tuple(peaks.T)]
    valid_peaks = be.multiply(peak_scores == peak_scores, valid_peaks)
    if self.min_score is not None:
        valid_peaks = be.multiply(peak_scores >= self.min_score, valid_peaks)

    if self.max_score is not None:
        valid_peaks = be.multiply(peak_scores <= self.max_score, valid_peaks)

    if be.sum(valid_peaks) == 0:
        return None

    # Ensure consistent upper limit of input peaks for _update step
    if (
        be.sum(valid_peaks) > self.num_peaks
        or peak_scores.shape[0] > self.num_peaks
    ):
        peak_indices = self._top_peaks(
            peaks, scores=peak_scores * valid_peaks, num_peaks=2 * self.num_peaks
        )
        top_peaks = be.full(peak_scores.shape, 0, bool)
        top_peaks[peak_indices] = True

        # Intersect with the existing mask. Replacing it would re-admit peaks
        # rejected by the boundary or NaN checks whenever they rank in the top.
        valid_peaks = be.multiply(top_peaks, valid_peaks)
        if be.sum(valid_peaks) == 0:
            return None

    if valid_peaks.shape[0] != peaks.shape[0]:
        return None
    return valid_peaks
|
282
|
+
|
283
|
+
def _apply_over_batch(func):
    """
    Decorator that turns ``func`` into a generator evaluated once per item
    produced by ``batchify(scores.shape, self.batch_dims)``. ``func``
    receives the sliced scores and the offset of the slice as
    ``batch_offset``.
    """

    @wraps(func)
    def wrapper(self, scores, rotation_matrix, **kwargs):
        batches = batchify(scores.shape, self.batch_dims)
        for batch_slice, batch_offset in batches:
            batch_scores = scores[batch_slice]
            yield func(
                self,
                scores=batch_scores,
                rotation_matrix=rotation_matrix,
                batch_offset=batch_offset,
                **kwargs,
            )

    return wrapper
|
296
|
+
|
297
|
+
@_apply_over_batch
def _call_peaks(self, scores, rotation_matrix, batch_offset=None, **kwargs):
    """
    Run :py:meth:`PeakCaller.call_peaks` on a single batch.

    The returned positions are converted to a backend array, shifted in
    place by ``batch_offset`` if given, and cast to int. ``(None, None)``
    is returned if the peak caller reports no positions.
    """
    positions, details = self.call_peaks(
        scores=scores,
        rotation_matrix=rotation_matrix,
        min_score=self.min_score,
        max_score=self.max_score,
        batch_offset=batch_offset,
        **kwargs,
    )
    if positions is None:
        return None, None

    positions = be.to_backend_array(positions)
    if batch_offset is not None:
        shift = be.to_backend_array(batch_offset)
        positions = be.add(positions, shift, out=positions)

    return be.astype(positions, int), details
|
317
|
+
|
318
|
+
def __call__(self, scores: BackendArray, rotation_matrix: BackendArray, **kwargs):
    """
    Update the internal parameter store based on input array.

    Parameters
    ----------
    scores : BackendArray
        Score space data.
    rotation_matrix : BackendArray
        Rotation matrix used to obtain the score array.
    **kwargs
        Optional keyword arguments passed to :py:meth:`PeakCaller.call_peaks`.
        ``min_score``, ``max_score`` and ``batch_offset`` are always taken
        from the instance and batching logic and are ignored if passed here.
    """
    # These are supplied explicitly by _call_peaks; forwarding them as well
    # would raise a duplicate keyword argument error.
    reserved = ("min_score", "max_score", "batch_offset")
    call_kwargs = {key: value for key, value in kwargs.items() if key not in reserved}

    for ret in self._call_peaks(
        scores=scores, rotation_matrix=rotation_matrix, **call_kwargs
    ):
        peak_positions, peak_details = ret
        if peak_positions is None:
            continue

        valid_peaks = self._get_peak_mask(peaks=peak_positions, scores=scores)
        if valid_peaks is None:
            continue

        peak_positions = peak_positions[valid_peaks]
        peak_scores = scores[tuple(peak_positions.T)]
        if peak_details is not None:
            peak_details = peak_details[valid_peaks]
        else:
            # -1 marks peaks for which the caller provided no details
            peak_details = be.full(peak_scores.shape, fill_value=-1)

        # Every peak of this call shares the same rotation matrix
        rotations = be.repeat(
            rotation_matrix.reshape(1, *rotation_matrix.shape),
            peak_positions.shape[0],
            axis=0,
        )

        self._update(
            peak_positions=peak_positions,
            peak_details=peak_details,
            peak_scores=peak_scores,
            rotations=rotations,
        )

    return None
|
362
|
+
|
363
|
+
@abstractmethod
def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
    """
    Identify peaks in a score array. Has to be implemented by subclasses.

    Parameters
    ----------
    scores : BackendArray
        Array of scores to call peaks on.
    **kwargs : dict
        Additional keyword arguments for the concrete implementation.

    Returns
    -------
    Tuple[BackendArray, BackendArray]
        Peak coordinates and peak details.
    """
|
380
|
+
|
381
|
+
@classmethod
def merge(cls, candidates: List[List], **kwargs) -> Tuple:
    """
    Merge multiple instances of :py:class:`PeakCaller`.

    Parameters
    ----------
    candidates : list of lists
        Obtained by invoking list on the generator returned by __iter__.
        Empty entries are skipped.
    **kwargs
        Optional keyword arguments passed to the class constructor. If
        ``shape`` is omitted it is inferred from the first non-empty
        candidate. ``offset`` is additionally passed to the update step
        of every candidate.

    Returns
    -------
    Tuple
        Tuple of translation, rotation, score and details of candidates.

    Raises
    ------
    ValueError
        If ``shape`` is not given and all candidates are empty.
    """
    if "shape" not in kwargs:
        # Only the dimensionality matters here. Use the first non-empty
        # candidate rather than blindly indexing candidates[0].
        reference = next((x for x in candidates if len(x) != 0), None)
        if reference is None:
            raise ValueError(
                "Can not infer shape from empty candidates, pass shape explicitly."
            )
        kwargs["shape"] = tuple(1 for _ in range(reference[0].shape[1]))

    base = cls(**kwargs)
    for candidate in candidates:
        if len(candidate) == 0:
            continue
        peak_positions, rotations, peak_scores, peak_details = candidate
        base._update(
            peak_positions=be.to_backend_array(peak_positions),
            peak_details=be.to_backend_array(peak_details),
            peak_scores=be.to_backend_array(peak_scores),
            rotations=be.to_backend_array(rotations),
            offset=kwargs.get("offset", None),
        )
    return tuple(base)
|
414
|
+
|
415
|
+
@staticmethod
def oversample_peaks(
    scores: BackendArray, peak_positions: BackendArray, oversampling_factor: int = 8
):
    """
    Refine peak positions to subpixel accuracy.

    Parameters
    ----------
    scores : BackendArray
        The d-dimensional score array.
    peak_positions : BackendArray
        Peak coordinates to refine with shape (n, d).
    oversampling_factor : int, optional
        Oversampling factor of the Fourier upsampling. Defaults to 8.

    Returns
    -------
    BackendArray
        Refined peak coordinates with shape (n, d).

    Notes
    -----
    The score array is upsampled in a small region around each rounded
    peak position using a matrix-multiplication DFT. The maximum of the
    upsampled region defines the refined position, so the accuracy scales
    with 1 / oversampling_factor.

    References
    ----------
    .. [1] https://scikit-image.org/docs/stable/api/skimage.registration.html
    .. [2] Manuel Guizar-Sicairos, Samuel T. Thurman, and
           James R. Fienup, "Efficient subpixel image registration
           algorithms," Optics Letters 33, 156-158 (2008).
           DOI:10.1364/OL.33.000156
    """
    scores = be.to_numpy_array(scores)
    peak_positions = be.to_numpy_array(peak_positions)

    # Multiply/divide round trip yields floating point positions before rounding
    scaled_positions = np.multiply(peak_positions, oversampling_factor)
    peak_positions = np.round(np.divide(scaled_positions, oversampling_factor))

    upsampled_region_size = np.ceil(np.multiply(oversampling_factor, 1.5))
    dftshift = np.round(np.divide(upsampled_region_size, 2.0))
    sample_region_offset = np.subtract(
        dftshift, np.multiply(peak_positions, oversampling_factor)
    )

    scores_ft = np.fft.fftn(scores).conj()
    for index, axis_offsets in enumerate(sample_region_offset):
        upsampled = _upsampled_dft(
            data=scores_ft,
            upsampled_region_size=upsampled_region_size,
            upsample_factor=oversampling_factor,
            axis_offsets=axis_offsets,
        ).conj()

        flat_maximum = np.argmax(np.abs(upsampled))
        maxima = np.unravel_index(flat_maximum, upsampled.shape)

        # Convert from upsampled region coordinates to a subpixel shift
        shift = np.divide(np.subtract(maxima, dftshift), oversampling_factor)
        peak_positions[index] = np.add(peak_positions[index], shift)

    return be.to_backend_array(peak_positions)
|
487
|
+
|
488
|
+
def _top_peaks(self, positions, scores, num_peaks: int = None):
    """
    Determine indices of the highest scoring peaks.

    Parameters
    ----------
    positions : BackendArray
        Peak coordinates with shape (n, d).
    scores : BackendArray
        Score of each peak (n,).
    num_peaks : int, optional
        Maximum number of peaks to return. Without batch dimensions this
        is a global limit, otherwise it applies to each batch separately.
        Defaults to the number of scores.

    Returns
    -------
    BackendArray
        Indices into ``scores`` of the selected peaks.
    """
    num_peaks = be.size(scores) if not num_peaks else num_peaks

    if self.batch_dims is None:
        top_n = min(be.size(scores), num_peaks)
        top_scores, *_ = be.topk_indices(scores, top_n)
        return top_scores

    # Not very performant but fairly robust
    batch_indices = positions[..., self.batch_dims]
    batch_indices = be.subtract(batch_indices, be.min(batch_indices, axis=0))

    # Encode the batch index tuple as one integer using cumulative strides.
    # Per-axis powers are not injective and merge distinct batches once
    # more than one batch dimension is used.
    extents = be.max(batch_indices, axis=0) + 1
    batch_ids, stride = 0, 1
    for dim in range(batch_indices.shape[1]):
        batch_ids = batch_ids + batch_indices[:, dim] * stride
        stride = stride * int(extents[dim])

    unique_indices, batch_counts = be.unique(batch_ids, return_counts=True)
    total_indices = be.arange(scores.shape[0])
    batch_members = [total_indices[batch_ids == x] for x in unique_indices]
    top_scores = be.concatenate(
        [
            indices[be.topk_indices(scores[indices], min(y, num_peaks))]
            for indices, y in zip(batch_members, batch_counts)
        ]
    )
    return top_scores
|
517
|
+
|
518
|
+
def _update(
    self,
    peak_positions: BackendArray,
    peak_details: BackendArray,
    peak_scores: BackendArray,
    rotations: BackendArray,
    offset: BackendArray = None,
):
    """
    Update internal parameter store.

    Parameters
    ----------
    peak_positions : BackendArray
        Position of peaks (n, d).
    peak_details : BackendArray
        Details of each peak (n, ).
    peak_scores: BackendArray
        Score at each peak (n,).
    rotations: BackendArray
        Rotation at each peak (n, d, d).
    offset : BackendArray, optional
        Translation offset, e.g. from splitting, (d, ).

    Notes
    -----
    New peaks are concatenated to the stored ones, filtered by the minimum
    distance constraint and truncated to the ``num_peaks`` best scores.
    ``peak_positions`` is not modified.
    """
    if offset is not None:
        offset = be.astype(be.to_backend_array(offset), peak_positions.dtype)
        # Allocate a new array instead of writing into the argument, so that
        # repeated merges of the same candidates do not accumulate offsets.
        peak_positions = be.add(peak_positions, offset)

    positions = be.concatenate((self.translations, peak_positions))
    rotations = be.concatenate((self.rotations, rotations))
    scores = be.concatenate((self.scores, peak_scores))
    details = be.concatenate((self.details, peak_details))

    # topk filtering after distances yields more distributed peak calls
    distance_order = filter_points_indices(
        coordinates=positions,
        min_distance=self.min_distance,
        batch_dims=self.batch_dims,
        scores=scores,
    )

    top_scores = self._top_peaks(
        positions[distance_order, :], scores[distance_order], self.num_peaks
    )
    # Map indices of the filtered subset back to the concatenated arrays
    final_order = distance_order[top_scores]

    self.translations = positions[final_order, :]
    self.rotations = rotations[final_order, :]
    self.scores = scores[final_order]
    self.details = details[final_order]
|
568
|
+
|
569
|
+
def _postprocess(self, **kwargs):
    """
    Convert stored translations using ``score_to_cart`` and drop all peaks
    it flags as invalid. Returns the instance itself.
    """
    if len(self.translations):
        positions, valid_peaks = score_to_cart(self.translations, **kwargs)

        self.translations = positions[valid_peaks]
        # Keep the remaining per-peak arrays aligned with the translations
        for name in ("rotations", "scores", "details"):
            setattr(self, name, getattr(self, name)[valid_peaks])

    return self
|
580
|
+
|
581
|
+
|
582
|
+
class PeakCallerSort(PeakCaller):
    """
    A :py:class:`PeakCaller` subclass that considers the ``num_peaks``
    highest scoring elements of the score array as peaks.
    """

    def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
        """
        Return the coordinates of the top scoring elements with shape
        (n, d). No peak details are computed.
        """
        flattened = scores.reshape(-1)
        n_select = min(self.num_peaks, be.size(flattened))

        flat_indices, *_ = be.topk_indices(flattened, n_select)

        # Flat indices -> per-axis index arrays -> (n, d) coordinate array
        per_axis = be.unravel_index(flat_indices, scores.shape)
        return be.transpose(be.stack(per_axis)), None
|
598
|
+
|
599
|
+
|
600
|
+
class PeakCallerMaximumFilter(PeakCaller):
    """
    Peak caller that determines local maxima using a maximum filter and
    subsequently enforces a distance constraint, similar to the strategy
    implemented in :obj:`skimage.feature.peak_local_max`.
    """

    def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
        """
        Return the coordinates reported by ``be.max_filter_coordinates``
        for ``self.min_distance``. No peak details are computed.
        """
        coordinates = be.max_filter_coordinates(scores, self.min_distance)
        return coordinates, None
|
609
|
+
|
610
|
+
|
611
|
+
class PeakCallerFast(PeakCaller):
    """
    Subdivides the score space into squares with edge length ``min_distance``
    and determines the maximum value for each. In a second pass, all local
    maxima that are not the maximum in a ``min_distance`` square centered
    around them are removed.
    """

    def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
        """
        Call peaks in the score space.

        Parameters
        ----------
        scores : BackendArray
            Score array.
        **kwargs : dict
            Ignored.

        Returns
        -------
        Tuple[BackendArray, BackendArray]
            Peak coordinates sorted by decreasing score and None, or
            ``(None, None)`` if no candidate was found.
        """
        # Guard against division by zero (min_distance == 0) and against
        # zero splits along axes that are shorter than min_distance.
        edge_length = max(self.min_distance, 1)
        splits = {i: max(x // edge_length, 1) for i, x in enumerate(scores.shape)}
        slices = split_shape(scores.shape, splits)

        # First pass: maximum of every subvolume, in subvolume coordinates
        coordinates = be.to_backend_array(
            [
                be.unravel_index(be.argmax(scores[subvol]), scores[subvol].shape)
                for subvol in slices
            ]
        )
        offset = be.to_backend_array(
            [tuple(x.start for x in subvol) for subvol in slices]
        )
        be.add(coordinates, offset, out=coordinates)
        coordinates = coordinates[be.argsort(-scores[tuple(coordinates.T)])]

        # Callers unpack the return value, so always return a pair
        if coordinates.shape[0] == 0:
            return None, None

        # Without a distance constraint there is nothing left to suppress
        if self.min_distance < 1:
            return coordinates, None

        # Second pass: keep candidates that are the maximum of the box
        # centered around them
        starts = be.maximum(coordinates - self.min_distance, 0)
        stops = be.minimum(coordinates + self.min_distance, scores.shape)
        slices_list = [
            tuple(slice(*coord) for coord in zip(start_row, stop_row))
            for start_row, stop_row in zip(starts, stops)
        ]

        keep = [
            score_subvol >= be.max(scores[subvol])
            for subvol, score_subvol in zip(slices_list, scores[tuple(coordinates.T)])
        ]
        coordinates = coordinates[keep,]

        return coordinates, None
|
656
|
+
|
657
|
+
|
658
|
+
class PeakCallerRecursiveMasking(PeakCaller):
|
659
|
+
"""
|
660
|
+
Identifies peaks iteratively by selecting the top score and masking
|
661
|
+
a region around it.
|
662
|
+
"""
|
663
|
+
|
664
|
+
def call_peaks(
    self,
    scores: BackendArray,
    rotation_matrix: BackendArray,
    mask: BackendArray = None,
    min_score: float = None,
    rotations: BackendArray = None,
    rotation_mapping: Dict = None,
    **kwargs,
) -> PeakType:
    """
    Call peaks in the score space.

    Parameters
    ----------
    scores : BackendArray
        Data array of scores.
    rotation_matrix : BackendArray
        Rotation matrix.
    mask : BackendArray, optional
        Mask array, by default None.
    rotations : BackendArray, optional
        Rotation space array, by default None.
    rotation_mapping : Dict optional
        Dictionary mapping values in rotations to Euler angles.
        By default None
    min_score : float
        Minimum score value to consider. If provided, supersedes limit given
        by :py:attr:`PeakCaller.num_peaks`.

    Returns
    -------
    Tuple[BackendArray, BackendArray]
        Array of peak coordinates and peak details.

    Notes
    -----
    By default, scores are masked using a box with edge length self.min_distance.
    If mask is provided, elements around each peak will be multiplied by the mask
    values. If rotations and rotation_mapping is provided, the respective
    rotation will be applied to the mask, otherwise rotation_matrix is used.
    """
    coordinates, masking_function = [], self._mask_scores_rotate

    if mask is None:
        masking_function = self._mask_scores_box
        shape = tuple(self.min_distance for _ in range(scores.ndim))
        mask = be.zeros(shape, dtype=be._float_dtype)

    # Scratch buffer handed to the masking function
    rotated_template = be.zeros(mask.shape, dtype=mask.dtype)

    peak_limit = self.num_peaks
    if min_score is not None:
        peak_limit = be.size(scores)
    else:
        # Threshold below every score, so only peak_limit ends the loop
        min_score = be.min(scores) - 1

    # Work on a copy, the masking functions modify the scores in place
    scores_copy = be.zeros(scores.shape, dtype=scores.dtype)
    scores_copy[:] = scores

    while True:
        peak = be.unravel_index(
            indices=be.argmax(scores_copy), shape=scores_copy.shape
        )
        if scores_copy[tuple(peak)] < min_score:
            break

        coordinates.append(peak)

        current_rotation_matrix = self._get_rotation_matrix(
            peak=peak,
            rotation_space=rotations,
            rotation_mapping=rotation_mapping,
            rotation_matrix=rotation_matrix,
        )

        masking_function(
            scores=scores_copy,
            rotation_matrix=current_rotation_matrix,
            peak=peak,
            mask=mask,
            rotated_template=rotated_template,
        )

        if len(coordinates) >= peak_limit:
            break

    peaks = be.to_backend_array(coordinates)
    return peaks, None
|
754
|
+
|
755
|
+
@staticmethod
def _get_rotation_matrix(
    peak: BackendArray,
    rotation_space: BackendArray,
    rotation_mapping: BackendArray,
    rotation_matrix: BackendArray,
) -> BackendArray:
    """
    Select the rotation matrix that belongs to a given peak.

    Parameters
    ----------
    peak : BackendArray
        Peak coordinates, used as an index into ``rotation_space``.
    rotation_space : BackendArray
        Array whose value at ``peak`` is a key of ``rotation_mapping``.
    rotation_mapping : Dict
        Lookup from values in ``rotation_space`` to a rotation. Entries
        with two dimensions are returned unchanged; any other entry is
        passed through ``euler_to_rotationmatrix``.
    rotation_matrix : BackendArray
        Fallback returned when ``rotation_space`` or ``rotation_mapping``
        is None.

    Returns
    -------
    BackendArray
        Rotation matrix for the peak.
    """
    lookup_available = rotation_space is not None and rotation_mapping is not None
    if not lookup_available:
        return rotation_matrix

    mapping_key = rotation_space[tuple(peak)]
    rotation = rotation_mapping[mapping_key]
    if rotation.ndim == 2:
        return rotation

    # TODO: Newer versions of rotation mapping contain rotation matrices not angles
    angles = be.to_numpy_array(rotation)
    return be.to_backend_array(euler_to_rotationmatrix(angles))
|
792
|
+
|
793
|
+
@staticmethod
def _mask_scores_box(
    scores: BackendArray, peak: BackendArray, mask: BackendArray, **kwargs: Dict
) -> None:
    """
    Zero out, in place, a box of scores centered on a peak.

    The box reaches ``mask.shape`` elements to either side of ``peak`` along
    every axis and is clipped to the bounds of ``scores``. Only the shape of
    ``mask`` is used, not its values.

    Parameters
    ----------
    scores : BackendArray
        Data array of scores, modified in place.
    peak : BackendArray
        Peak coordinates.
    mask : BackendArray
        Mask array whose shape defines the half-extent of the box.
    **kwargs : Dict
        Ignored; accepted so both masking functions share a call signature.
    """
    half_extent = mask.shape
    lower = be.astype(be.maximum(be.subtract(peak, half_extent), 0), int)
    upper = be.astype(be.minimum(be.add(peak, half_extent), scores.shape), int)

    window = tuple(slice(lo, hi) for lo, hi in zip(lower, upper))
    scores[window] = 0
|
815
|
+
|
816
|
+
@staticmethod
def _mask_scores_rotate(
    scores: BackendArray,
    peak: BackendArray,
    mask: BackendArray,
    rotated_template: BackendArray,
    rotation_matrix: BackendArray,
    **kwargs: Dict,
) -> None:
    """
    Suppress scores around a peak using a rotated copy of the mask.

    The mask is rotated by ``rotation_matrix`` into ``rotated_template`` and
    aligned so that its center sits on ``peak``. Scores under rotated-mask
    values greater than 0.1 are set to zero; all others are left unchanged.
    The part of the mask that would fall outside ``scores`` is cropped.

    Parameters
    ----------
    scores : BackendArray
        Data array of scores, modified in place.
    peak : BackendArray
        Peak coordinates.
    mask : BackendArray
        Mask array.
    rotated_template : BackendArray
        Scratch array the rotated mask is written to; it is reset to zero
        on every call.
    rotation_matrix : BackendArray
        Rotation matrix.
    **kwargs : Dict
        Ignored; accepted so both masking functions share a call signature.
    """
    # Split the mask extent around the peak; odd sizes put the extra
    # element on the right-hand side.
    half = be.divide(mask.shape, 2).astype(int)
    remainder = be.mod(mask.shape, 2).astype(int)

    raw_start = be.subtract(peak, half)
    raw_stop = be.add(peak, be.add(half, remainder))

    # Window in score space, clipped to the array bounds.
    clipped_start = be.maximum(raw_start, 0)
    clipped_stop = be.minimum(raw_stop, scores.shape)

    # Matching window in mask space: drop whatever was clipped above.
    mask_start = be.subtract(clipped_start, raw_start)
    overhang = be.subtract(raw_stop, clipped_stop)
    mask_stop = be.subtract(mask.shape, overhang)

    clipped_start = be.astype(clipped_start, int)
    clipped_stop = be.astype(clipped_stop, int)
    mask_start = be.astype(mask_start, int)
    mask_stop = be.astype(mask_stop, int)

    score_window = tuple(
        slice(lo, hi) for lo, hi in zip(clipped_start, clipped_stop)
    )
    mask_window = tuple(slice(lo, hi) for lo, hi in zip(mask_start, mask_stop))

    rotated_template.fill(0)
    be.rigid_transform(
        arr=mask, rotation_matrix=rotation_matrix, order=1, out=rotated_template
    )

    # Multiplying by the boolean keeps scores outside the rotated mask and
    # zeroes those inside it.
    outside_mask = rotated_template[mask_window] <= 0.1
    scores[score_window] = be.multiply(scores[score_window], outside_mask)
|
872
|
+
|
873
|
+
|
874
|
+
class PeakCallerScipy(PeakCaller):
    """
    Peak calling using :obj:`skimage.feature.peak_local_max` to compute local maxima.
    """

    def call_peaks(
        self, scores: BackendArray, min_score: float = None, **kwargs
    ) -> PeakType:
        """
        Call peaks as local maxima of the score array.

        Parameters
        ----------
        scores : BackendArray
            Data array of scores. Axes of length one are squeezed before the
            search and reported as coordinate zero in the result.
        min_score : float, optional
            Passed as ``threshold_abs`` to ``peak_local_max``. When given,
            the number of peaks is not capped by ``self.num_peaks``.
        **kwargs
            Ignored.

        Returns
        -------
        PeakType
            Tuple of peak coordinates, with one row per peak and one column
            per axis of ``scores``, and None for the peak details.
        """
        scores = be.to_numpy_array(scores)

        # An explicit score threshold lifts the cap on the number of peaks.
        peak_cap = self.num_peaks if min_score is None else np.inf

        kept_axes = tuple(
            axis for axis, extent in enumerate(scores.shape) if extent != 1
        )
        local_maxima = peak_local_max(
            np.squeeze(scores),
            num_peaks=peak_cap,
            min_distance=self.min_distance,
            threshold_abs=min_score,
        )

        # Scatter the squeezed coordinates back into full dimensionality;
        # columns of singleton axes keep their zero initialisation.
        expanded = np.zeros(
            (local_maxima.shape[0], scores.ndim), local_maxima.dtype
        )
        expanded[..., kept_axes] = local_maxima[:]
        return be.to_backend_array(expanded), None
|
898
|
+
|
899
|
+
|
900
|
+
class PeakClustering(PeakCallerSort):
    """
    Use DBScan clustering to identify more reliable peaks.
    """

    def __init__(
        self,
        num_peaks: int = 1000,
        **kwargs,
    ):
        """
        Initialize the peak caller with the minimum peak distance set to zero.

        Parameters
        ----------
        num_peaks : int, optional
            Number of peaks, forwarded to the parent class. 1000 by default.
        **kwargs
            Further keyword arguments forwarded to the parent class. Any
            ``min_distance`` given here is overridden with 0.
        """
        kwargs["min_distance"] = 0
        super().__init__(num_peaks=num_peaks, **kwargs)

    @classmethod
    def merge(cls, **kwargs) -> NDArray:
        """
        Merge multiple instances of Analyzer.

        The merged peaks are clustered with DBSCAN and only the
        highest-scoring member of each cluster is kept. Peaks that DBSCAN
        labels as noise (label -1) are discarded.

        Parameters
        ----------
        **kwargs
            Optional keyword arguments passed to :py:meth:`PeakCaller.merge`.

        Returns
        -------
        NDArray
            Tuple of peaks, rotations, scores and details, each restricted
            to the cluster representatives.
        """
        from sklearn.cluster import DBSCAN
        from ..extensions import max_index_by_label

        peaks, rotations, scores, details = super().merge(**kwargs)

        # DBSCAN cannot be fitted without samples; nothing to filter either.
        n_peaks = peaks.shape[0]
        if n_peaks == 0:
            return peaks, rotations, scores, details

        # NOTE: the merged scores are used as-is. They were previously
        # replaced by the third coordinate of every peak, so representatives
        # were ranked by position rather than by score.
        clusters = DBSCAN(eps=np.finfo(float).eps, min_samples=8).fit(peaks)
        labels = clusters.labels_.astype(int)

        label_max = max_index_by_label(labels=labels, scores=scores)
        # Label -1 marks DBSCAN noise points, which have no representative.
        label_max.pop(-1, None)

        keep = np.isin(np.arange(n_peaks), list(label_max.values()))
        peaks = peaks[keep,]
        rotations = rotations[keep,]
        scores = scores[keep]
        details = details[keep]

        return peaks, rotations, scores, details
|