pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3.0__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.0.data/scripts/estimate_memory_usage.py +76 -0
- pytme-0.3.0.data/scripts/match_template.py +1106 -0
- {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/postprocess.py +320 -190
- {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocess.py +21 -31
- {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocessor_gui.py +85 -19
- pytme-0.3.0.data/scripts/pytme_runner.py +771 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/METADATA +22 -20
- pytme-0.3.0.dist-info/RECORD +126 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/entry_points.txt +2 -1
- pytme-0.3.0.dist-info/licenses/LICENSE +339 -0
- scripts/estimate_memory_usage.py +76 -0
- scripts/eval.py +93 -0
- scripts/extract_candidates.py +224 -0
- scripts/match_template.py +349 -378
- pytme-0.2.9.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
- scripts/postprocess.py +320 -190
- scripts/preprocess.py +21 -31
- scripts/preprocessor_gui.py +85 -19
- scripts/pytme_runner.py +771 -0
- scripts/refine_matches.py +625 -0
- tests/preprocessing/test_frequency_filters.py +28 -14
- tests/test_analyzer.py +41 -36
- tests/test_backends.py +1 -0
- tests/test_matching_cli.py +109 -53
- tests/test_matching_data.py +5 -5
- tests/test_matching_exhaustive.py +1 -2
- tests/test_matching_optimization.py +4 -9
- tests/test_matching_utils.py +1 -1
- tests/test_orientations.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +26 -21
- tme/analyzer/aggregation.py +396 -222
- tme/analyzer/base.py +127 -0
- tme/analyzer/peaks.py +189 -201
- tme/analyzer/proxy.py +123 -0
- tme/backends/__init__.py +4 -3
- tme/backends/_cupy_utils.py +25 -24
- tme/backends/_jax_utils.py +20 -18
- tme/backends/cupy_backend.py +13 -26
- tme/backends/jax_backend.py +24 -23
- tme/backends/matching_backend.py +4 -3
- tme/backends/mlx_backend.py +4 -3
- tme/backends/npfftw_backend.py +34 -30
- tme/backends/pytorch_backend.py +18 -4
- tme/cli.py +126 -0
- tme/density.py +9 -7
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/__init__.py +3 -3
- tme/filters/_utils.py +36 -10
- tme/filters/bandpass.py +229 -188
- tme/filters/compose.py +5 -4
- tme/filters/ctf.py +516 -254
- tme/filters/reconstruction.py +91 -32
- tme/filters/wedge.py +196 -135
- tme/filters/whitening.py +37 -42
- tme/matching_data.py +28 -39
- tme/matching_exhaustive.py +31 -27
- tme/matching_optimization.py +5 -4
- tme/matching_scores.py +25 -15
- tme/matching_utils.py +158 -28
- tme/memory.py +4 -3
- tme/orientations.py +22 -9
- tme/parser.py +114 -33
- tme/preprocessor.py +6 -5
- tme/rotations.py +10 -7
- tme/structure.py +4 -3
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +0 -97
- pytme-0.2.9.dist-info/RECORD +0 -119
- pytme-0.2.9.dist-info/licenses/LICENSE +0 -153
- scripts/estimate_ram_usage.py +0 -97
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/WHEEL +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/top_level.txt +0 -0
tme/analyzer/base.py
ADDED
@@ -0,0 +1,127 @@
|
|
1
|
+
"""
|
2
|
+
Implements abstract base class for template matching analyzers.
|
3
|
+
|
4
|
+
Copyright (c) 2025 European Molecular Biology Laboratory
|
5
|
+
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
|
+
"""
|
8
|
+
|
9
|
+
from typing import Tuple, List
|
10
|
+
from abc import ABC, abstractmethod
|
11
|
+
|
12
|
+
__all__ = ["AbstractAnalyzer"]
|
13
|
+
|
14
|
+
|
15
|
+
class AbstractAnalyzer(ABC):
|
16
|
+
"""
|
17
|
+
Abstract base class for template matching analyzers.
|
18
|
+
"""
|
19
|
+
|
20
|
+
@property
|
21
|
+
def shareable(self):
|
22
|
+
"""
|
23
|
+
Indicate whether the analyzer can be shared across processes.
|
24
|
+
|
25
|
+
Returns
|
26
|
+
-------
|
27
|
+
bool
|
28
|
+
True if the analyzer supports shared memory operations
|
29
|
+
and can be safely used across multiple processes, False
|
30
|
+
if it should only be used within a single process.
|
31
|
+
"""
|
32
|
+
return False
|
33
|
+
|
34
|
+
@abstractmethod
|
35
|
+
def init_state(self, *args, **kwargs) -> Tuple:
|
36
|
+
"""
|
37
|
+
Initialize the analyzer state.
|
38
|
+
|
39
|
+
Returns
|
40
|
+
-------
|
41
|
+
state : tuple
|
42
|
+
Initial state tuple of the analyzer instance. The exact structure
|
43
|
+
depends on the specific implementation.
|
44
|
+
|
45
|
+
Notes
|
46
|
+
-----
|
47
|
+
This method creates the initial state that will be passed to
|
48
|
+
:py:meth:`AbstractAnalyzer.__call__` and finally to
|
49
|
+
:py:meth:`AbstractAnalyzer.result`. The state should contain all necessary
|
50
|
+
data structures for accumulating analysis results.
|
51
|
+
"""
|
52
|
+
|
53
|
+
@abstractmethod
|
54
|
+
def __call__(self, state, scores, rotation_matrix, **kwargs) -> Tuple:
|
55
|
+
"""
|
56
|
+
Update the analyzer state with new scoring data.
|
57
|
+
|
58
|
+
Parameters
|
59
|
+
----------
|
60
|
+
state : tuple
|
61
|
+
Current analyzer state as returned by :py:meth:`AbstractAnalyzer.init_state`
|
62
|
+
or previous invocations of :py:meth:`AbstractAnalyzer.__call__`.
|
63
|
+
scores : BackendArray
|
64
|
+
Array of new scores with dimensionality d.
|
65
|
+
rotation_matrix : BackendArray
|
66
|
+
Rotation matrix used to generate scores with shape (d,d).
|
67
|
+
**kwargs : dict
|
68
|
+
Keyword arguments used by specific implementations.
|
69
|
+
|
70
|
+
Returns
|
71
|
+
-------
|
72
|
+
tuple
|
73
|
+
Updated analyzer state incorporating the new data.
|
74
|
+
"""
|
75
|
+
|
76
|
+
@abstractmethod
|
77
|
+
def result(self, state: Tuple, **kwargs) -> Tuple:
|
78
|
+
"""
|
79
|
+
Finalize the analysis by performing potential postprocessing.
|
80
|
+
|
81
|
+
Parameters
|
82
|
+
----------
|
83
|
+
state : tuple
|
84
|
+
Analyzer state containing accumulated data.
|
85
|
+
**kwargs : dict
|
86
|
+
Keyword arguments used by specific implementations.
|
87
|
+
|
88
|
+
Returns
|
89
|
+
-------
|
90
|
+
result
|
91
|
+
Final analysis result. The exact structure depends on the
|
92
|
+
analyzer implementation.
|
93
|
+
|
94
|
+
Notes
|
95
|
+
-----
|
96
|
+
This method converts the internal analyzer state into the
|
97
|
+
final output format expected by the template matching pipeline.
|
98
|
+
It may apply postprocessing operations like convolution mode
|
99
|
+
correction or coordinate transformations.
|
100
|
+
"""
|
101
|
+
|
102
|
+
@classmethod
|
103
|
+
@abstractmethod
|
104
|
+
def merge(cls, results: List[Tuple], **kwargs) -> Tuple:
|
105
|
+
"""
|
106
|
+
Merge multiple analyzer results.
|
107
|
+
|
108
|
+
Parameters
|
109
|
+
----------
|
110
|
+
results : list of tuple
|
111
|
+
List of tuple objects returned by :py:meth:`AbstractAnalyzer.result`
|
112
|
+
from different instances of the same analyzer class.
|
113
|
+
**kwargs : dict
|
114
|
+
Keyword arguments used by specific implementations.
|
115
|
+
|
116
|
+
Returns
|
117
|
+
-------
|
118
|
+
tuple
|
119
|
+
Single result object combining all input results.
|
120
|
+
|
121
|
+
Notes
|
122
|
+
-----
|
123
|
+
This method enables parallel processing by allowing results
|
124
|
+
from different processes or splits to be combined into a
|
125
|
+
unified result. The merge operation should handle overlapping
|
126
|
+
data appropriately and maintain consistency.
|
127
|
+
"""
|
tme/analyzer/peaks.py
CHANGED
@@ -1,23 +1,25 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
Implements classes to analyze outputs from exhaustive template matching.
|
2
3
|
|
3
|
-
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
5
|
|
5
|
-
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
7
|
"""
|
7
8
|
|
8
9
|
from functools import wraps
|
9
|
-
from abc import
|
10
|
-
from typing import Tuple, List, Dict
|
10
|
+
from abc import abstractmethod
|
11
|
+
from typing import Tuple, List, Dict
|
11
12
|
|
12
13
|
import numpy as np
|
13
14
|
from skimage.feature import peak_local_max
|
14
15
|
from skimage.registration._phase_cross_correlation import _upsampled_dft
|
15
16
|
|
17
|
+
from .base import AbstractAnalyzer
|
16
18
|
from ._utils import score_to_cart
|
17
19
|
from ..backends import backend as be
|
18
|
-
from ..matching_utils import split_shape
|
19
20
|
from ..types import BackendArray, NDArray
|
20
21
|
from ..rotations import euler_to_rotationmatrix
|
22
|
+
from ..matching_utils import split_shape, compute_extraction_box
|
21
23
|
|
22
24
|
__all__ = [
|
23
25
|
"PeakCaller",
|
@@ -141,7 +143,7 @@ def batchify(shape: Tuple[int], batch_dims: Tuple[int] = None) -> List:
|
|
141
143
|
yield from _generate_slices_recursive(0, ())
|
142
144
|
|
143
145
|
|
144
|
-
class PeakCaller(
|
146
|
+
class PeakCaller(AbstractAnalyzer):
|
145
147
|
"""
|
146
148
|
Base class for peak calling algorithms.
|
147
149
|
|
@@ -190,16 +192,7 @@ class PeakCaller(ABC):
|
|
190
192
|
if min_boundary_distance < 0:
|
191
193
|
raise ValueError("min_boundary_distance has to be non-negative.")
|
192
194
|
|
193
|
-
|
194
|
-
self.translations = be.full(
|
195
|
-
(num_peaks, ndim), fill_value=-1, dtype=be._int_dtype
|
196
|
-
)
|
197
|
-
self.rotations = be.full(
|
198
|
-
(num_peaks, ndim, ndim), fill_value=0, dtype=be._float_dtype
|
199
|
-
)
|
200
|
-
self.scores = be.full((num_peaks,), fill_value=0, dtype=be._float_dtype)
|
201
|
-
self.details = be.full((num_peaks,), fill_value=0, dtype=be._float_dtype)
|
202
|
-
|
195
|
+
self.shape = shape
|
203
196
|
self.num_peaks = int(num_peaks)
|
204
197
|
self.min_distance = int(min_distance)
|
205
198
|
self.min_boundary_distance = int(min_boundary_distance)
|
@@ -210,31 +203,47 @@ class PeakCaller(ABC):
|
|
210
203
|
|
211
204
|
self.min_score, self.max_score = min_score, max_score
|
212
205
|
|
213
|
-
|
214
|
-
|
215
|
-
self.convolution_mode = kwargs.get("convolution_mode", None)
|
216
|
-
self.targetshape = kwargs.get("targetshape", None)
|
217
|
-
self.templateshape = kwargs.get("templateshape", None)
|
218
|
-
|
219
|
-
def __iter__(self) -> Generator:
|
206
|
+
@abstractmethod
|
207
|
+
def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
|
220
208
|
"""
|
221
|
-
|
222
|
-
|
209
|
+
Call peaks in the score space.
|
210
|
+
|
211
|
+
Parameters
|
212
|
+
----------
|
213
|
+
scores : BackendArray
|
214
|
+
Score array to update analyzer with.
|
215
|
+
**kwargs : dict
|
216
|
+
Optional keyword arguments passed to underlying implementations.
|
217
|
+
|
218
|
+
Returns
|
219
|
+
-------
|
220
|
+
BackendArray
|
221
|
+
Peak positions (n, d).
|
222
|
+
BackendArray
|
223
|
+
Peak details (n, ).
|
223
224
|
"""
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
225
|
+
|
226
|
+
def init_state(self):
|
227
|
+
ndim = len(self.shape)
|
228
|
+
translations = be.full(
|
229
|
+
(self.num_peaks, ndim), fill_value=-1, dtype=be._int_dtype
|
230
|
+
)
|
231
|
+
rotations = be.full(
|
232
|
+
(self.num_peaks, ndim, ndim), fill_value=0, dtype=be._float_dtype
|
233
|
+
)
|
234
|
+
for i in range(ndim):
|
235
|
+
rotations[:, i, i] = 1.0
|
236
|
+
|
237
|
+
scores = be.full((self.num_peaks,), fill_value=-1, dtype=be._float_dtype)
|
238
|
+
details = be.full((self.num_peaks,), fill_value=-1, dtype=be._float_dtype)
|
239
|
+
return translations, rotations, scores, details
|
231
240
|
|
232
241
|
def _get_peak_mask(self, peaks: BackendArray, scores: BackendArray) -> BackendArray:
|
233
242
|
if not len(peaks):
|
234
243
|
return None
|
235
244
|
|
236
245
|
valid_peaks = be.full((peaks.shape[0],), fill_value=1) == 1
|
237
|
-
if self.min_boundary_distance
|
246
|
+
if self.min_boundary_distance >= 0:
|
238
247
|
upper_limit = be.subtract(
|
239
248
|
be.to_backend_array(scores.shape), self.min_boundary_distance
|
240
249
|
)
|
@@ -315,20 +324,34 @@ class PeakCaller(ABC):
|
|
315
324
|
peak_positions = be.astype(peak_positions, int)
|
316
325
|
return peak_positions, peak_details
|
317
326
|
|
318
|
-
def __call__(
|
327
|
+
def __call__(
|
328
|
+
self,
|
329
|
+
state: Tuple,
|
330
|
+
scores: BackendArray,
|
331
|
+
rotation_matrix: BackendArray,
|
332
|
+
**kwargs,
|
333
|
+
) -> Tuple:
|
319
334
|
"""
|
320
335
|
Update the internal parameter store based on input array.
|
321
336
|
|
322
337
|
Parameters
|
323
338
|
----------
|
339
|
+
state : tuple
|
340
|
+
Current state tuple where:
|
341
|
+
- positions : BackendArray, (n, d) of peak positions
|
342
|
+
- rotations : BackendArray, (n, d, d) of corresponding rotations
|
343
|
+
- scores : BackendArray, (n, ) of peak scores
|
344
|
+
- details : BackendArray, (n, ) of peak details
|
324
345
|
scores : BackendArray
|
325
|
-
|
346
|
+
Array of new scores to update analyzer with.
|
326
347
|
rotation_matrix : BackendArray
|
327
348
|
Rotation matrix used to obtain the score array.
|
328
349
|
**kwargs
|
329
350
|
Optional keyword arguments passed to :py:meth:`PeakCaller.call_peaks`.
|
330
351
|
"""
|
331
|
-
for ret in self._call_peaks(
|
352
|
+
for ret in self._call_peaks(
|
353
|
+
scores=scores, rotation_matrix=rotation_matrix, **kwargs
|
354
|
+
):
|
332
355
|
peak_positions, peak_details = ret
|
333
356
|
if peak_positions is None:
|
334
357
|
continue
|
@@ -341,7 +364,6 @@ class PeakCaller(ABC):
|
|
341
364
|
peak_scores = scores[tuple(peak_positions.T)]
|
342
365
|
if peak_details is not None:
|
343
366
|
peak_details = peak_details[valid_peaks]
|
344
|
-
# peak_details, peak_scores = peak_scores, -peak_details
|
345
367
|
else:
|
346
368
|
peak_details = be.full(peak_scores.shape, fill_value=-1)
|
347
369
|
|
@@ -351,66 +373,57 @@ class PeakCaller(ABC):
|
|
351
373
|
axis=0,
|
352
374
|
)
|
353
375
|
|
354
|
-
self._update(
|
376
|
+
state = self._update(
|
377
|
+
state,
|
355
378
|
peak_positions=peak_positions,
|
356
379
|
peak_details=peak_details,
|
357
380
|
peak_scores=peak_scores,
|
358
|
-
|
381
|
+
peak_rotations=rotations,
|
359
382
|
)
|
360
383
|
|
361
|
-
return
|
362
|
-
|
363
|
-
@abstractmethod
|
364
|
-
def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
|
365
|
-
"""
|
366
|
-
Call peaks in the score space.
|
367
|
-
|
368
|
-
Parameters
|
369
|
-
----------
|
370
|
-
scores : BackendArray
|
371
|
-
Score array.
|
372
|
-
**kwargs : dict
|
373
|
-
Optional keyword arguments passed to underlying implementations.
|
374
|
-
|
375
|
-
Returns
|
376
|
-
-------
|
377
|
-
Tuple[BackendArray, BackendArray]
|
378
|
-
Array of peak coordinates and peak details.
|
379
|
-
"""
|
384
|
+
return state
|
380
385
|
|
381
386
|
@classmethod
|
382
|
-
def merge(cls,
|
387
|
+
def merge(cls, results=List[Tuple], **kwargs) -> Tuple:
|
383
388
|
"""
|
384
389
|
Merge multiple instances of :py:class:`PeakCaller`.
|
385
390
|
|
386
391
|
Parameters
|
387
392
|
----------
|
388
|
-
|
389
|
-
|
393
|
+
results : list of tuple
|
394
|
+
List of instance results created by applying `result`.
|
390
395
|
**kwargs
|
391
396
|
Optional keyword arguments.
|
392
397
|
|
393
398
|
Returns
|
394
399
|
-------
|
395
|
-
|
396
|
-
|
400
|
+
NDArray
|
401
|
+
Peak positions (n, d).
|
402
|
+
NDArray
|
403
|
+
Peak rotation matrices (n, d, d).
|
404
|
+
NDArray
|
405
|
+
Peak scores (n, ).
|
406
|
+
NDArray
|
407
|
+
Peak details (n,).
|
397
408
|
"""
|
398
409
|
if "shape" not in kwargs:
|
399
|
-
kwargs["shape"] = tuple(1 for _ in range(
|
410
|
+
kwargs["shape"] = tuple(1 for _ in range(results[0][0].shape[1]))
|
400
411
|
|
401
412
|
base = cls(**kwargs)
|
402
|
-
|
403
|
-
|
413
|
+
base_state = base.init_state()
|
414
|
+
for result in results:
|
415
|
+
if len(result) == 0:
|
404
416
|
continue
|
405
|
-
peak_positions, rotations, peak_scores, peak_details =
|
406
|
-
base._update(
|
417
|
+
peak_positions, rotations, peak_scores, peak_details = result
|
418
|
+
base_state = base._update(
|
419
|
+
base_state,
|
407
420
|
peak_positions=be.to_backend_array(peak_positions),
|
408
421
|
peak_details=be.to_backend_array(peak_details),
|
409
422
|
peak_scores=be.to_backend_array(peak_scores),
|
410
|
-
|
423
|
+
peak_rotations=be.to_backend_array(rotations),
|
411
424
|
offset=kwargs.get("offset", None),
|
412
425
|
)
|
413
|
-
return
|
426
|
+
return base_state
|
414
427
|
|
415
428
|
@staticmethod
|
416
429
|
def oversample_peaks(
|
@@ -517,10 +530,11 @@ class PeakCaller(ABC):
|
|
517
530
|
|
518
531
|
def _update(
|
519
532
|
self,
|
533
|
+
state,
|
520
534
|
peak_positions: BackendArray,
|
521
535
|
peak_details: BackendArray,
|
522
536
|
peak_scores: BackendArray,
|
523
|
-
|
537
|
+
peak_rotations: BackendArray,
|
524
538
|
offset: BackendArray = None,
|
525
539
|
):
|
526
540
|
"""
|
@@ -539,14 +553,15 @@ class PeakCaller(ABC):
|
|
539
553
|
offset : BackendArray, optional
|
540
554
|
Translation offset, e.g. from splitting, (d, ).
|
541
555
|
"""
|
556
|
+
translations, rotations, scores, details = state
|
542
557
|
if offset is not None:
|
543
558
|
offset = be.astype(be.to_backend_array(offset), peak_positions.dtype)
|
544
559
|
peak_positions = be.add(peak_positions, offset, out=peak_positions)
|
545
560
|
|
546
|
-
positions = be.concatenate((
|
547
|
-
rotations = be.concatenate((
|
548
|
-
scores = be.concatenate((
|
549
|
-
details = be.concatenate((
|
561
|
+
positions = be.concatenate((translations, peak_positions))
|
562
|
+
rotations = be.concatenate((rotations, peak_rotations))
|
563
|
+
scores = be.concatenate((scores, peak_scores))
|
564
|
+
details = be.concatenate((details, peak_details))
|
550
565
|
|
551
566
|
# topk filtering after distances yields more distributed peak calls
|
552
567
|
distance_order = filter_points_indices(
|
@@ -561,22 +576,69 @@ class PeakCaller(ABC):
|
|
561
576
|
)
|
562
577
|
final_order = distance_order[top_scores]
|
563
578
|
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
579
|
+
translations = positions[final_order, :]
|
580
|
+
rotations = rotations[final_order, :]
|
581
|
+
scores = scores[final_order]
|
582
|
+
details = details[final_order]
|
583
|
+
return translations, rotations, scores, details
|
568
584
|
|
569
|
-
def
|
570
|
-
|
571
|
-
|
585
|
+
def result(
|
586
|
+
self,
|
587
|
+
state,
|
588
|
+
fast_shape: Tuple[int] = None,
|
589
|
+
targetshape: Tuple[int] = None,
|
590
|
+
templateshape: Tuple[int] = None,
|
591
|
+
convolution_shape: Tuple[int] = None,
|
592
|
+
fourier_shift: Tuple[int] = None,
|
593
|
+
convolution_mode: str = None,
|
594
|
+
**kwargs,
|
595
|
+
):
|
596
|
+
"""
|
597
|
+
Finalize the analysis result with optional postprocessing.
|
572
598
|
|
573
|
-
|
599
|
+
Parameters
|
600
|
+
----------
|
601
|
+
state : tuple
|
602
|
+
Current state tuple where:
|
603
|
+
- positions : BackendArray, (n, d) of peak positions
|
604
|
+
- rotations : BackendArray, (n, d, d) of corresponding rotations
|
605
|
+
- scores : BackendArray, (n, ) of peak scores
|
606
|
+
- details : BackendArray, (n, ) of peak details
|
607
|
+
targetshape : Tuple[int], optional
|
608
|
+
Shape of the target for convolution mode correction.
|
609
|
+
templateshape : Tuple[int], optional
|
610
|
+
Shape of the template for convolution mode correction.
|
611
|
+
convolution_shape : Tuple[int], optional
|
612
|
+
Shape used for convolution.
|
613
|
+
fourier_shift : Tuple[int], optional
|
614
|
+
Shift to apply for Fourier correction.
|
615
|
+
convolution_mode : str, optional
|
616
|
+
Convolution mode for padding correction.
|
617
|
+
**kwargs
|
618
|
+
Additional keyword arguments.
|
574
619
|
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
620
|
+
Returns
|
621
|
+
-------
|
622
|
+
tuple
|
623
|
+
Final result tuple (positions, rotations, scores, details).
|
624
|
+
"""
|
625
|
+
translations, rotations, scores, details = state
|
626
|
+
|
627
|
+
positions, valid_peaks = score_to_cart(
|
628
|
+
positions=translations,
|
629
|
+
fast_shape=fast_shape,
|
630
|
+
targetshape=targetshape,
|
631
|
+
templateshape=templateshape,
|
632
|
+
convolution_shape=convolution_shape,
|
633
|
+
fourier_shift=fourier_shift,
|
634
|
+
convolution_mode=convolution_mode,
|
635
|
+
**kwargs,
|
636
|
+
)
|
637
|
+
translations = be.to_cpu_array(positions[valid_peaks])
|
638
|
+
rotations = be.to_cpu_array(rotations[valid_peaks])
|
639
|
+
scores = be.to_cpu_array(scores[valid_peaks])
|
640
|
+
details = be.to_cpu_array(details[valid_peaks])
|
641
|
+
return translations, rotations, scores, details
|
580
642
|
|
581
643
|
|
582
644
|
class PeakCallerSort(PeakCaller):
|
@@ -703,14 +765,14 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
703
765
|
values. If rotations and rotation_mapping is provided, the respective
|
704
766
|
rotation will be applied to the mask, otherwise rotation_matrix is used.
|
705
767
|
"""
|
706
|
-
|
768
|
+
peaks = []
|
769
|
+
box = tuple(self.min_distance for _ in range(scores.ndim))
|
707
770
|
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
mask = be.
|
712
|
-
|
713
|
-
rotated_template = be.zeros(mask.shape, dtype=mask.dtype)
|
771
|
+
scores = be.to_backend_array(scores)
|
772
|
+
if mask is not None:
|
773
|
+
box = mask.shape
|
774
|
+
mask = be.to_backend_array(mask)
|
775
|
+
mask_buffer = be.zeros(mask.shape, dtype=mask.dtype)
|
714
776
|
|
715
777
|
peak_limit = self.num_peaks
|
716
778
|
if min_score is not None:
|
@@ -718,39 +780,45 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
718
780
|
else:
|
719
781
|
min_score = be.min(scores) - 1
|
720
782
|
|
721
|
-
|
722
|
-
|
723
|
-
|
783
|
+
_scores = be.zeros(scores.shape, dtype=scores.dtype)
|
784
|
+
_scores[:] = scores[:]
|
724
785
|
while True:
|
725
|
-
be.argmax(
|
726
|
-
peak
|
727
|
-
indices=be.argmax(scores_copy), shape=scores_copy.shape
|
728
|
-
)
|
729
|
-
if scores_copy[tuple(peak)] < min_score:
|
786
|
+
peak = be.unravel_index(indices=be.argmax(_scores), shape=_scores.shape)
|
787
|
+
if _scores[tuple(peak)] < min_score:
|
730
788
|
break
|
789
|
+
peaks.append(peak)
|
731
790
|
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
rotation_space=rotations,
|
737
|
-
rotation_mapping=rotation_mapping,
|
738
|
-
rotation_matrix=rotation_matrix,
|
791
|
+
score_beg, score_end, tmpl_beg, tmpl_end, _ = compute_extraction_box(
|
792
|
+
centers=be.to_backend_array(peak)[None],
|
793
|
+
extraction_shape=box,
|
794
|
+
original_shape=scores.shape,
|
739
795
|
)
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
mask=mask,
|
746
|
-
rotated_template=rotated_template,
|
796
|
+
score_slice = tuple(
|
797
|
+
slice(int(x), int(y)) for x, y in zip(score_beg[0], score_end[0])
|
798
|
+
)
|
799
|
+
tmpl_slice = tuple(
|
800
|
+
slice(int(x), int(y)) for x, y in zip(tmpl_beg[0], tmpl_end[0])
|
747
801
|
)
|
748
802
|
|
749
|
-
|
803
|
+
score_mask = 0
|
804
|
+
if mask is not None:
|
805
|
+
mask_buffer.fill(0)
|
806
|
+
rmat = self._get_rotation_matrix(
|
807
|
+
peak=peak,
|
808
|
+
rotation_space=rotations,
|
809
|
+
rotation_mapping=rotation_mapping,
|
810
|
+
rotation_matrix=rotation_matrix,
|
811
|
+
)
|
812
|
+
be.rigid_transform(
|
813
|
+
arr=mask, rotation_matrix=rmat, order=1, out=mask_buffer
|
814
|
+
)
|
815
|
+
score_mask = mask_buffer[tmpl_slice] <= 0.1
|
816
|
+
|
817
|
+
_scores[score_slice] = be.multiply(_scores[score_slice], score_mask)
|
818
|
+
if len(peaks) >= peak_limit:
|
750
819
|
break
|
751
820
|
|
752
|
-
|
753
|
-
return peaks, None
|
821
|
+
return be.to_backend_array(peaks), None
|
754
822
|
|
755
823
|
@staticmethod
|
756
824
|
def _get_rotation_matrix(
|
@@ -790,86 +858,6 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
790
858
|
)
|
791
859
|
return rotation
|
792
860
|
|
793
|
-
@staticmethod
|
794
|
-
def _mask_scores_box(
|
795
|
-
scores: BackendArray, peak: BackendArray, mask: BackendArray, **kwargs: Dict
|
796
|
-
) -> None:
|
797
|
-
"""
|
798
|
-
Mask scores in a box around a peak.
|
799
|
-
|
800
|
-
Parameters
|
801
|
-
----------
|
802
|
-
scores : BackendArray
|
803
|
-
Data array of scores.
|
804
|
-
peak : BackendArray
|
805
|
-
Peak coordinates.
|
806
|
-
mask : BackendArray
|
807
|
-
Mask array.
|
808
|
-
"""
|
809
|
-
start = be.maximum(be.subtract(peak, mask.shape), 0)
|
810
|
-
stop = be.minimum(be.add(peak, mask.shape), scores.shape)
|
811
|
-
start, stop = be.astype(start, int), be.astype(stop, int)
|
812
|
-
coords = tuple(slice(*pos) for pos in zip(start, stop))
|
813
|
-
scores[coords] = 0
|
814
|
-
return None
|
815
|
-
|
816
|
-
@staticmethod
|
817
|
-
def _mask_scores_rotate(
|
818
|
-
scores: BackendArray,
|
819
|
-
peak: BackendArray,
|
820
|
-
mask: BackendArray,
|
821
|
-
rotated_template: BackendArray,
|
822
|
-
rotation_matrix: BackendArray,
|
823
|
-
**kwargs: Dict,
|
824
|
-
) -> None:
|
825
|
-
"""
|
826
|
-
Mask scores using mask rotation around a peak.
|
827
|
-
|
828
|
-
Parameters
|
829
|
-
----------
|
830
|
-
scores : BackendArray
|
831
|
-
Data array of scores.
|
832
|
-
peak : BackendArray
|
833
|
-
Peak coordinates.
|
834
|
-
mask : BackendArray
|
835
|
-
Mask array.
|
836
|
-
rotated_template : BackendArray
|
837
|
-
Empty array to write mask rotations to.
|
838
|
-
rotation_matrix : BackendArray
|
839
|
-
Rotation matrix.
|
840
|
-
"""
|
841
|
-
left_pad = be.divide(mask.shape, 2).astype(int)
|
842
|
-
right_pad = be.add(left_pad, be.mod(mask.shape, 2).astype(int))
|
843
|
-
|
844
|
-
score_start = be.subtract(peak, left_pad)
|
845
|
-
score_stop = be.add(peak, right_pad)
|
846
|
-
|
847
|
-
template_start = be.subtract(be.maximum(score_start, 0), score_start)
|
848
|
-
template_stop = be.subtract(score_stop, be.minimum(score_stop, scores.shape))
|
849
|
-
template_stop = be.subtract(mask.shape, template_stop)
|
850
|
-
|
851
|
-
score_start = be.maximum(score_start, 0)
|
852
|
-
score_stop = be.minimum(score_stop, scores.shape)
|
853
|
-
score_start = be.astype(score_start, int)
|
854
|
-
score_stop = be.astype(score_stop, int)
|
855
|
-
|
856
|
-
template_start = be.astype(template_start, int)
|
857
|
-
template_stop = be.astype(template_stop, int)
|
858
|
-
coords_score = tuple(slice(*pos) for pos in zip(score_start, score_stop))
|
859
|
-
coords_template = tuple(
|
860
|
-
slice(*pos) for pos in zip(template_start, template_stop)
|
861
|
-
)
|
862
|
-
|
863
|
-
rotated_template.fill(0)
|
864
|
-
be.rigid_transform(
|
865
|
-
arr=mask, rotation_matrix=rotation_matrix, order=1, out=rotated_template
|
866
|
-
)
|
867
|
-
|
868
|
-
scores[coords_score] = be.multiply(
|
869
|
-
scores[coords_score], (rotated_template[coords_template] <= 0.1)
|
870
|
-
)
|
871
|
-
return None
|
872
|
-
|
873
861
|
|
874
862
|
class PeakCallerScipy(PeakCaller):
|
875
863
|
"""
|