pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +1 -5
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/match_template.py +177 -226
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/postprocess.py +69 -47
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocess.py +10 -23
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +98 -28
- pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/METADATA +15 -15
- pytme-0.3.1.dist-info/RECORD +133 -0
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +1 -0
- pytme-0.3.1.dist-info/licenses/LICENSE +339 -0
- scripts/estimate_memory_usage.py +1 -5
- scripts/eval.py +93 -0
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +177 -226
- scripts/match_template_filters.py +1200 -0
- scripts/postprocess.py +69 -47
- scripts/preprocess.py +10 -23
- scripts/preprocessor_gui.py +98 -28
- scripts/pytme_runner.py +1223 -0
- scripts/refine_matches.py +156 -387
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_frequency_filters.py +19 -10
- tests/preprocessing/test_utils.py +18 -0
- tests/test_analyzer.py +122 -122
- tests/test_backends.py +4 -9
- tests/test_density.py +0 -1
- tests/test_matching_cli.py +30 -30
- tests/test_matching_data.py +5 -5
- tests/test_matching_utils.py +11 -61
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +1 -1
- tme/analyzer/_utils.py +5 -8
- tme/analyzer/aggregation.py +28 -9
- tme/analyzer/base.py +25 -36
- tme/analyzer/peaks.py +49 -122
- tme/analyzer/proxy.py +1 -0
- tme/backends/_jax_utils.py +31 -28
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +11 -54
- tme/backends/jax_backend.py +72 -48
- tme/backends/matching_backend.py +6 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +95 -90
- tme/backends/pytorch_backend.py +5 -26
- tme/density.py +7 -10
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/__init__.py +2 -2
- tme/filters/_utils.py +32 -7
- tme/filters/bandpass.py +225 -186
- tme/filters/ctf.py +138 -87
- tme/filters/reconstruction.py +38 -9
- tme/filters/wedge.py +98 -112
- tme/filters/whitening.py +1 -6
- tme/mask.py +341 -0
- tme/matching_data.py +20 -44
- tme/matching_exhaustive.py +46 -56
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +216 -412
- tme/matching_utils.py +82 -424
- tme/memory.py +1 -1
- tme/orientations.py +16 -8
- tme/parser.py +109 -29
- tme/preprocessor.py +2 -2
- tme/rotations.py +1 -1
- pytme-0.3b0.dist-info/RECORD +0 -122
- pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
tests/test_matching_utils.py
CHANGED
@@ -10,9 +10,6 @@ from tme.backends import backend as be
|
|
10
10
|
from tme.memory import MATCHING_MEMORY_REGISTRY
|
11
11
|
from tme.matching_utils import (
|
12
12
|
compute_parallelization_schedule,
|
13
|
-
elliptical_mask,
|
14
|
-
box_mask,
|
15
|
-
tube_mask,
|
16
13
|
create_mask,
|
17
14
|
scramble_phases,
|
18
15
|
apply_convolution_mode,
|
@@ -50,73 +47,26 @@ class TestMatchingUtils:
|
|
50
47
|
max_splits=256,
|
51
48
|
)
|
52
49
|
|
53
|
-
|
50
|
+
@pytest.mark.parametrize("mask_type", ["ellipse", "box", "tube", "membrane"])
|
51
|
+
def test_create_mask(self, mask_type: str):
|
54
52
|
create_mask(
|
55
|
-
mask_type=
|
53
|
+
mask_type=mask_type,
|
56
54
|
shape=self.density.shape,
|
57
55
|
radius=5,
|
58
56
|
center=np.divide(self.density.shape, 2),
|
57
|
+
height=np.max(self.density.shape) // 2,
|
58
|
+
size=np.divide(self.density.shape, 2).astype(int),
|
59
|
+
thickness=2,
|
60
|
+
separation=2,
|
61
|
+
symmetry_axis=1,
|
62
|
+
inner_radius=5,
|
63
|
+
outer_radius=10,
|
59
64
|
)
|
60
65
|
|
61
66
|
def test_create_mask_error(self):
|
62
67
|
with pytest.raises(ValueError):
|
63
68
|
create_mask(mask_type=None)
|
64
69
|
|
65
|
-
def test_elliptical_mask(self):
|
66
|
-
elliptical_mask(
|
67
|
-
shape=self.density.shape,
|
68
|
-
radius=5,
|
69
|
-
center=np.divide(self.density.shape, 2),
|
70
|
-
)
|
71
|
-
|
72
|
-
def test_box_mask(self):
|
73
|
-
box_mask(
|
74
|
-
shape=self.density.shape,
|
75
|
-
height=[5, 10, 20],
|
76
|
-
center=np.divide(self.density.shape, 2),
|
77
|
-
)
|
78
|
-
|
79
|
-
def test_tube_mask(self):
|
80
|
-
tube_mask(
|
81
|
-
shape=self.density.shape,
|
82
|
-
outer_radius=10,
|
83
|
-
inner_radius=5,
|
84
|
-
height=5,
|
85
|
-
base_center=np.divide(self.density.shape, 2),
|
86
|
-
symmetry_axis=1,
|
87
|
-
)
|
88
|
-
|
89
|
-
def test_tube_mask_error(self):
|
90
|
-
with pytest.raises(ValueError):
|
91
|
-
tube_mask(
|
92
|
-
shape=self.density.shape,
|
93
|
-
outer_radius=5,
|
94
|
-
inner_radius=10,
|
95
|
-
height=5,
|
96
|
-
base_center=np.divide(self.density.shape, 2),
|
97
|
-
symmetry_axis=1,
|
98
|
-
)
|
99
|
-
|
100
|
-
with pytest.raises(ValueError):
|
101
|
-
tube_mask(
|
102
|
-
shape=self.density.shape,
|
103
|
-
outer_radius=5,
|
104
|
-
inner_radius=10,
|
105
|
-
height=10 * np.max(self.density.shape),
|
106
|
-
base_center=np.divide(self.density.shape, 2),
|
107
|
-
symmetry_axis=1,
|
108
|
-
)
|
109
|
-
|
110
|
-
with pytest.raises(ValueError):
|
111
|
-
tube_mask(
|
112
|
-
shape=self.density.shape,
|
113
|
-
outer_radius=5,
|
114
|
-
inner_radius=10,
|
115
|
-
height=10 * np.max(self.density.shape),
|
116
|
-
base_center=np.divide(self.density.shape, 2),
|
117
|
-
symmetry_axis=len(self.density.shape) + 1,
|
118
|
-
)
|
119
|
-
|
120
70
|
def test_scramble_phases(self):
|
121
71
|
scramble_phases(arr=self.density.data, noise_proportion=0.5)
|
122
72
|
|
@@ -139,7 +89,7 @@ class TestMatchingUtils:
|
|
139
89
|
expected_size = np.subtract(
|
140
90
|
self.density.shape, self.structure_density.shape
|
141
91
|
)
|
142
|
-
expected_size +=
|
92
|
+
expected_size += 1
|
143
93
|
assert np.allclose(ret.shape, expected_size)
|
144
94
|
|
145
95
|
def test_apply_convolution_mode_error(self):
|
tests/test_rotations.py
CHANGED
@@ -8,6 +8,7 @@ from tme import Density
|
|
8
8
|
from scipy.spatial.transform import Rotation
|
9
9
|
from scipy.signal import correlate
|
10
10
|
|
11
|
+
from tme.mask import elliptical_mask
|
11
12
|
from tme.rotations import (
|
12
13
|
euler_from_rotationmatrix,
|
13
14
|
euler_to_rotationmatrix,
|
@@ -16,7 +17,6 @@ from tme.rotations import (
|
|
16
17
|
get_rotation_matrices,
|
17
18
|
)
|
18
19
|
from tme.matching_utils import (
|
19
|
-
elliptical_mask,
|
20
20
|
split_shape,
|
21
21
|
compute_full_convolution_index,
|
22
22
|
)
|
tme/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.3.
|
1
|
+
__version__ = "0.3.1"
|
tme/analyzer/__init__.py
CHANGED
tme/analyzer/_utils.py
CHANGED
@@ -47,10 +47,7 @@ def _convmode_to_shape(
|
|
47
47
|
if convolution_mode == "same":
|
48
48
|
output_shape = targetshape
|
49
49
|
elif convolution_mode == "valid":
|
50
|
-
output_shape = be.
|
51
|
-
be.subtract(targetshape, templateshape),
|
52
|
-
be.mod(templateshape, 2),
|
53
|
-
)
|
50
|
+
output_shape = be.subtract(targetshape, templateshape) + 1
|
54
51
|
return be.to_backend_array(output_shape)
|
55
52
|
|
56
53
|
|
@@ -96,22 +93,22 @@ def cart_to_score(
|
|
96
93
|
templateshape = be.to_backend_array(templateshape)
|
97
94
|
convolution_shape = be.to_backend_array(convolution_shape)
|
98
95
|
|
99
|
-
# Compute removed padding
|
100
96
|
output_shape = _convmode_to_shape(
|
101
97
|
convolution_mode=convolution_mode,
|
102
98
|
targetshape=targetshape,
|
103
99
|
templateshape=templateshape,
|
104
100
|
convolution_shape=convolution_shape,
|
105
101
|
)
|
106
|
-
valid_positions = be.multiply(positions >= 0, positions < output_shape)
|
107
|
-
valid_positions = be.sum(valid_positions, axis=1) == positions.shape[1]
|
108
102
|
|
103
|
+
# Offset from padding the target
|
109
104
|
starts = be.astype(
|
110
105
|
be.divide(be.subtract(convolution_shape, output_shape), 2),
|
111
106
|
be._int_dtype,
|
112
107
|
)
|
113
|
-
|
114
108
|
positions = be.add(positions, starts)
|
109
|
+
|
110
|
+
valid_positions = be.multiply(positions >= 0, positions < fast_shape)
|
111
|
+
valid_positions = be.sum(valid_positions, axis=1) == positions.shape[1]
|
115
112
|
if fourier_shift is not None:
|
116
113
|
fourier_shift = be.to_backend_array(fourier_shift)
|
117
114
|
positions = be.subtract(positions, fourier_shift)
|
tme/analyzer/aggregation.py
CHANGED
@@ -57,22 +57,23 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
57
57
|
The following achieves the minimal definition of a :py:class:`MaxScoreOverRotations`
|
58
58
|
instance
|
59
59
|
|
60
|
+
>>> import numpy as np
|
60
61
|
>>> from tme.analyzer import MaxScoreOverRotations
|
61
|
-
>>> analyzer = MaxScoreOverRotations(shape
|
62
|
+
>>> analyzer = MaxScoreOverRotations(shape=(50, 50))
|
62
63
|
|
63
64
|
The following simulates a template matching run by creating random data for a range
|
64
65
|
of rotations and sending it to ``analyzer`` via its __call__ method
|
65
66
|
|
66
|
-
|
67
|
+
>>> state = analyzer.init_state()
|
67
68
|
>>> for rotation_number in range(10):
|
68
69
|
>>> scores = np.random.rand(50,50)
|
69
70
|
>>> rotation = np.random.rand(scores.ndim, scores.ndim)
|
70
|
-
>>> state
|
71
|
+
>>> state = analyzer(state, scores=scores, rotation_matrix=rotation)
|
71
72
|
|
72
73
|
The aggregated scores can be extracted by invoking the result method of
|
73
74
|
``analyzer``
|
74
75
|
|
75
|
-
>>> results = analyzer.result()
|
76
|
+
>>> results = analyzer.result(state)
|
76
77
|
|
77
78
|
The ``results`` tuple contains (1) the maximum scores for each translation,
|
78
79
|
(2) an offset which is relevant when merging results from split template matching
|
@@ -100,6 +101,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
100
101
|
shm_handler: object = None,
|
101
102
|
use_memmap: bool = False,
|
102
103
|
inversion_mapping: bool = False,
|
104
|
+
jax_mode: bool = False,
|
103
105
|
**kwargs,
|
104
106
|
):
|
105
107
|
self._use_memmap = use_memmap
|
@@ -107,6 +109,10 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
107
109
|
self._shape = tuple(int(x) for x in shape)
|
108
110
|
self._inversion_mapping = inversion_mapping
|
109
111
|
|
112
|
+
self._jax_mode = jax_mode
|
113
|
+
if self._jax_mode:
|
114
|
+
self._inversion_mapping = False
|
115
|
+
|
110
116
|
if offset is None:
|
111
117
|
offset = be.zeros(len(self._shape), be._int_dtype)
|
112
118
|
self._offset = be.astype(be.to_backend_array(offset), int)
|
@@ -138,6 +144,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
138
144
|
state: Tuple,
|
139
145
|
scores: BackendArray,
|
140
146
|
rotation_matrix: BackendArray,
|
147
|
+
**kwargs,
|
141
148
|
) -> Tuple:
|
142
149
|
"""
|
143
150
|
Update the parameter store.
|
@@ -167,6 +174,8 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
167
174
|
rotation_matrix = be.astype(rotation_matrix, be._float_dtype)
|
168
175
|
if self._inversion_mapping:
|
169
176
|
rotation_mapping[rotation_index] = rotation_matrix
|
177
|
+
elif self._jax_mode:
|
178
|
+
rotation_index = kwargs.get("rotation_index", 0)
|
170
179
|
else:
|
171
180
|
rotation = be.tobytes(rotation_matrix)
|
172
181
|
rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
|
@@ -536,13 +545,19 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
|
|
536
545
|
)
|
537
546
|
|
538
547
|
def __call__(
|
539
|
-
self,
|
548
|
+
self,
|
549
|
+
state: Tuple,
|
550
|
+
scores: BackendArray,
|
551
|
+
rotation_matrix: BackendArray,
|
552
|
+
**kwargs,
|
540
553
|
) -> Tuple:
|
541
554
|
mask = self._get_constraint(rotation_matrix)
|
542
555
|
mask = self._get_score_mask(mask=mask, scores=scores)
|
543
556
|
|
544
557
|
scores = be.multiply(scores, mask, out=scores)
|
545
|
-
return super().__call__(
|
558
|
+
return super().__call__(
|
559
|
+
state, scores=scores, rotation_matrix=rotation_matrix, **kwargs
|
560
|
+
)
|
546
561
|
|
547
562
|
def _get_constraint(self, rotation_matrix: BackendArray) -> BackendArray:
|
548
563
|
"""
|
@@ -627,7 +642,11 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
627
642
|
return scores, rotations, {}
|
628
643
|
|
629
644
|
def __call__(
|
630
|
-
self,
|
645
|
+
self,
|
646
|
+
state,
|
647
|
+
scores: BackendArray,
|
648
|
+
rotation_matrix: BackendArray,
|
649
|
+
**kwargs,
|
631
650
|
) -> Tuple:
|
632
651
|
prev_scores, rotations, rotation_mapping = state
|
633
652
|
|
@@ -738,5 +757,5 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
738
757
|
cls._invert_rmap(master_rotation_mapping),
|
739
758
|
)
|
740
759
|
|
741
|
-
def
|
742
|
-
return
|
760
|
+
def result(self, state: Tuple, **kwargs) -> Tuple:
|
761
|
+
return state
|
tme/analyzer/base.py
CHANGED
@@ -38,16 +38,16 @@ class AbstractAnalyzer(ABC):
|
|
38
38
|
|
39
39
|
Returns
|
40
40
|
-------
|
41
|
-
state
|
42
|
-
Initial state tuple
|
43
|
-
|
44
|
-
implementation.
|
41
|
+
state : tuple
|
42
|
+
Initial state tuple of the analyzer instance. The exact structure
|
43
|
+
depends on the specific implementation.
|
45
44
|
|
46
45
|
Notes
|
47
46
|
-----
|
48
47
|
This method creates the initial state that will be passed to
|
49
|
-
|
50
|
-
|
48
|
+
:py:meth:`AbstractAnalyzer.__call__` and finally to
|
49
|
+
:py:meth:`AbstractAnalyzer.result`. The state should contain all necessary
|
50
|
+
data structures for accumulating analysis results.
|
51
51
|
"""
|
52
52
|
|
53
53
|
@abstractmethod
|
@@ -57,49 +57,39 @@ class AbstractAnalyzer(ABC):
|
|
57
57
|
|
58
58
|
Parameters
|
59
59
|
----------
|
60
|
-
state :
|
61
|
-
Current analyzer state as returned
|
62
|
-
previous
|
60
|
+
state : tuple
|
61
|
+
Current analyzer state as returned :py:meth:`AbstractAnalyzer.init_state`
|
62
|
+
or previous invocations of :py:meth:`AbstractAnalyzer.__call__`.
|
63
63
|
scores : BackendArray
|
64
|
-
Array of scores
|
64
|
+
Array of new scores with dimensionality d.
|
65
65
|
rotation_matrix : BackendArray
|
66
|
-
Rotation matrix used to generate
|
66
|
+
Rotation matrix used to generate scores with shape (d,d).
|
67
67
|
**kwargs : dict
|
68
|
-
|
69
|
-
implementation.
|
68
|
+
Keyword arguments used by specific implementations.
|
70
69
|
|
71
70
|
Returns
|
72
71
|
-------
|
73
|
-
|
74
|
-
Updated analyzer state
|
75
|
-
|
76
|
-
Notes
|
77
|
-
-----
|
78
|
-
This method should be pure functional - it should not modify
|
79
|
-
the input state but return a new state with the updates applied.
|
80
|
-
The exact signature may vary between implementations.
|
72
|
+
tuple
|
73
|
+
Updated analyzer state incorporating the new data.
|
81
74
|
"""
|
82
|
-
pass
|
83
75
|
|
84
76
|
@abstractmethod
|
85
77
|
def result(self, state: Tuple, **kwargs) -> Tuple:
|
86
78
|
"""
|
87
|
-
Finalize the analysis
|
79
|
+
Finalize the analysis by performing potential post processing.
|
88
80
|
|
89
81
|
Parameters
|
90
82
|
----------
|
91
83
|
state : tuple
|
92
|
-
|
84
|
+
Analyzer state containing accumulated data.
|
93
85
|
**kwargs : dict
|
94
|
-
|
95
|
-
such as postprocessing parameters.
|
86
|
+
Keyword arguments used by specific implementations.
|
96
87
|
|
97
88
|
Returns
|
98
89
|
-------
|
99
90
|
result
|
100
|
-
Final analysis result. The exact
|
101
|
-
analyzer implementation
|
102
|
-
scores, rotation information, and metadata.
|
91
|
+
Final analysis result. The exact struccture depends on the
|
92
|
+
analyzer implementation.
|
103
93
|
|
104
94
|
Notes
|
105
95
|
-----
|
@@ -108,25 +98,24 @@ class AbstractAnalyzer(ABC):
|
|
108
98
|
It may apply postprocessing operations like convolution mode
|
109
99
|
correction or coordinate transformations.
|
110
100
|
"""
|
111
|
-
pass
|
112
101
|
|
113
102
|
@classmethod
|
114
103
|
@abstractmethod
|
115
104
|
def merge(cls, results: List[Tuple], **kwargs) -> Tuple:
|
116
105
|
"""
|
117
|
-
Merge
|
106
|
+
Merge multiple analyzer results.
|
118
107
|
|
119
108
|
Parameters
|
120
109
|
----------
|
121
|
-
results : list
|
122
|
-
List of
|
123
|
-
from
|
110
|
+
results : list of tuple
|
111
|
+
List of tuple objects returned by :py:meth:`AbstractAnalyzer.result`
|
112
|
+
from different instances of the same analyzer class.
|
124
113
|
**kwargs : dict
|
125
|
-
|
114
|
+
Keyword arguments used by specific implementations.
|
126
115
|
|
127
116
|
Returns
|
128
117
|
-------
|
129
|
-
|
118
|
+
tuple
|
130
119
|
Single result object combining all input results.
|
131
120
|
|
132
121
|
Notes
|
tme/analyzer/peaks.py
CHANGED
@@ -17,9 +17,9 @@ from skimage.registration._phase_cross_correlation import _upsampled_dft
|
|
17
17
|
from .base import AbstractAnalyzer
|
18
18
|
from ._utils import score_to_cart
|
19
19
|
from ..backends import backend as be
|
20
|
-
from ..matching_utils import split_shape
|
21
20
|
from ..types import BackendArray, NDArray
|
22
21
|
from ..rotations import euler_to_rotationmatrix
|
22
|
+
from ..matching_utils import split_shape, compute_extraction_box
|
23
23
|
|
24
24
|
__all__ = [
|
25
25
|
"PeakCaller",
|
@@ -228,10 +228,15 @@ class PeakCaller(AbstractAnalyzer):
|
|
228
228
|
translations = be.full(
|
229
229
|
(self.num_peaks, ndim), fill_value=-1, dtype=be._int_dtype
|
230
230
|
)
|
231
|
+
|
232
|
+
rdim = len(self.shape)
|
233
|
+
if self.batch_dims:
|
234
|
+
rdim -= len(self.batch_dims)
|
235
|
+
|
231
236
|
rotations = be.full(
|
232
|
-
(self.num_peaks,
|
237
|
+
(self.num_peaks, rdim, rdim), fill_value=0, dtype=be._float_dtype
|
233
238
|
)
|
234
|
-
for i in range(
|
239
|
+
for i in range(rdim):
|
235
240
|
rotations[:, i, i] = 1.0
|
236
241
|
|
237
242
|
scores = be.full((self.num_peaks,), fill_value=-1, dtype=be._float_dtype)
|
@@ -750,8 +755,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
750
755
|
Dictionary mapping values in rotations to Euler angles.
|
751
756
|
By default None
|
752
757
|
min_score : float
|
753
|
-
Minimum score value to consider.
|
754
|
-
by :py:attr:`PeakCaller.num_peaks`.
|
758
|
+
Minimum score value to consider.
|
755
759
|
|
756
760
|
Returns
|
757
761
|
-------
|
@@ -765,54 +769,57 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
765
769
|
values. If rotations and rotation_mapping is provided, the respective
|
766
770
|
rotation will be applied to the mask, otherwise rotation_matrix is used.
|
767
771
|
"""
|
768
|
-
|
772
|
+
peaks = []
|
773
|
+
box = tuple(self.min_distance for _ in range(scores.ndim))
|
769
774
|
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
mask = be.
|
775
|
+
scores = be.to_backend_array(scores)
|
776
|
+
if mask is not None:
|
777
|
+
box = mask.shape
|
778
|
+
mask = be.to_backend_array(mask)
|
779
|
+
mask_buffer = be.zeros(mask.shape, dtype=mask.dtype)
|
774
780
|
|
775
|
-
|
776
|
-
|
777
|
-
peak_limit = self.num_peaks
|
778
|
-
if min_score is not None:
|
779
|
-
peak_limit = be.size(scores)
|
780
|
-
else:
|
781
|
+
if min_score is None:
|
781
782
|
min_score = be.min(scores) - 1
|
782
783
|
|
783
|
-
|
784
|
-
|
785
|
-
|
784
|
+
_scores = be.zeros(scores.shape, dtype=scores.dtype)
|
785
|
+
_scores[:] = scores[:]
|
786
786
|
while True:
|
787
|
-
be.argmax(
|
788
|
-
peak
|
789
|
-
indices=be.argmax(scores_copy), shape=scores_copy.shape
|
790
|
-
)
|
791
|
-
if scores_copy[tuple(peak)] < min_score:
|
787
|
+
peak = be.unravel_index(indices=be.argmax(_scores), shape=_scores.shape)
|
788
|
+
if _scores[tuple(peak)] < min_score:
|
792
789
|
break
|
790
|
+
peaks.append(peak)
|
793
791
|
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
rotation_space=rotations,
|
799
|
-
rotation_mapping=rotation_mapping,
|
800
|
-
rotation_matrix=rotation_matrix,
|
792
|
+
score_beg, score_end, tmpl_beg, tmpl_end, _ = compute_extraction_box(
|
793
|
+
centers=be.to_backend_array(peak)[None],
|
794
|
+
extraction_shape=box,
|
795
|
+
original_shape=scores.shape,
|
801
796
|
)
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
mask=mask,
|
808
|
-
rotated_template=rotated_template,
|
797
|
+
score_slice = tuple(
|
798
|
+
slice(int(x), int(y)) for x, y in zip(score_beg[0], score_end[0])
|
799
|
+
)
|
800
|
+
tmpl_slice = tuple(
|
801
|
+
slice(int(x), int(y)) for x, y in zip(tmpl_beg[0], tmpl_end[0])
|
809
802
|
)
|
810
803
|
|
811
|
-
|
804
|
+
score_mask = 0
|
805
|
+
if mask is not None:
|
806
|
+
mask_buffer.fill(0)
|
807
|
+
rmat = self._get_rotation_matrix(
|
808
|
+
peak=peak,
|
809
|
+
rotation_space=rotations,
|
810
|
+
rotation_mapping=rotation_mapping,
|
811
|
+
rotation_matrix=rotation_matrix,
|
812
|
+
)
|
813
|
+
be.rigid_transform(
|
814
|
+
arr=mask, rotation_matrix=rmat, order=1, out=mask_buffer
|
815
|
+
)
|
816
|
+
score_mask = mask_buffer[tmpl_slice] <= 0.1
|
817
|
+
|
818
|
+
_scores[score_slice] = be.multiply(_scores[score_slice], score_mask)
|
819
|
+
if len(peaks) >= self.num_peaks:
|
812
820
|
break
|
813
821
|
|
814
|
-
|
815
|
-
return peaks, None
|
822
|
+
return be.to_backend_array(peaks), None
|
816
823
|
|
817
824
|
@staticmethod
|
818
825
|
def _get_rotation_matrix(
|
@@ -845,93 +852,13 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
845
852
|
|
846
853
|
rotation = rotation_mapping[rotation_space[tuple(peak)]]
|
847
854
|
|
848
|
-
#
|
855
|
+
# Old versions of rotation mapping contained Euler angles
|
849
856
|
if rotation.ndim != 2:
|
850
857
|
rotation = be.to_backend_array(
|
851
858
|
euler_to_rotationmatrix(be.to_numpy_array(rotation))
|
852
859
|
)
|
853
860
|
return rotation
|
854
861
|
|
855
|
-
@staticmethod
|
856
|
-
def _mask_scores_box(
|
857
|
-
scores: BackendArray, peak: BackendArray, mask: BackendArray, **kwargs: Dict
|
858
|
-
) -> None:
|
859
|
-
"""
|
860
|
-
Mask scores in a box around a peak.
|
861
|
-
|
862
|
-
Parameters
|
863
|
-
----------
|
864
|
-
scores : BackendArray
|
865
|
-
Data array of scores.
|
866
|
-
peak : BackendArray
|
867
|
-
Peak coordinates.
|
868
|
-
mask : BackendArray
|
869
|
-
Mask array.
|
870
|
-
"""
|
871
|
-
start = be.maximum(be.subtract(peak, mask.shape), 0)
|
872
|
-
stop = be.minimum(be.add(peak, mask.shape), scores.shape)
|
873
|
-
start, stop = be.astype(start, int), be.astype(stop, int)
|
874
|
-
coords = tuple(slice(*pos) for pos in zip(start, stop))
|
875
|
-
scores[coords] = 0
|
876
|
-
return None
|
877
|
-
|
878
|
-
@staticmethod
|
879
|
-
def _mask_scores_rotate(
|
880
|
-
scores: BackendArray,
|
881
|
-
peak: BackendArray,
|
882
|
-
mask: BackendArray,
|
883
|
-
rotated_template: BackendArray,
|
884
|
-
rotation_matrix: BackendArray,
|
885
|
-
**kwargs: Dict,
|
886
|
-
) -> None:
|
887
|
-
"""
|
888
|
-
Mask scores using mask rotation around a peak.
|
889
|
-
|
890
|
-
Parameters
|
891
|
-
----------
|
892
|
-
scores : BackendArray
|
893
|
-
Data array of scores.
|
894
|
-
peak : BackendArray
|
895
|
-
Peak coordinates.
|
896
|
-
mask : BackendArray
|
897
|
-
Mask array.
|
898
|
-
rotated_template : BackendArray
|
899
|
-
Empty array to write mask rotations to.
|
900
|
-
rotation_matrix : BackendArray
|
901
|
-
Rotation matrix.
|
902
|
-
"""
|
903
|
-
left_pad = be.divide(mask.shape, 2).astype(int)
|
904
|
-
right_pad = be.add(left_pad, be.mod(mask.shape, 2).astype(int))
|
905
|
-
|
906
|
-
score_start = be.subtract(peak, left_pad)
|
907
|
-
score_stop = be.add(peak, right_pad)
|
908
|
-
|
909
|
-
template_start = be.subtract(be.maximum(score_start, 0), score_start)
|
910
|
-
template_stop = be.subtract(score_stop, be.minimum(score_stop, scores.shape))
|
911
|
-
template_stop = be.subtract(mask.shape, template_stop)
|
912
|
-
|
913
|
-
score_start = be.maximum(score_start, 0)
|
914
|
-
score_stop = be.minimum(score_stop, scores.shape)
|
915
|
-
score_start = be.astype(score_start, int)
|
916
|
-
score_stop = be.astype(score_stop, int)
|
917
|
-
|
918
|
-
template_start = be.astype(template_start, int)
|
919
|
-
template_stop = be.astype(template_stop, int)
|
920
|
-
coords_score = tuple(slice(*pos) for pos in zip(score_start, score_stop))
|
921
|
-
coords_template = tuple(
|
922
|
-
slice(*pos) for pos in zip(template_start, template_stop)
|
923
|
-
)
|
924
|
-
|
925
|
-
rotated_template.fill(0)
|
926
|
-
be.rigid_transform(
|
927
|
-
arr=mask, rotation_matrix=rotation_matrix, order=1, out=rotated_template
|
928
|
-
)
|
929
|
-
|
930
|
-
scores[coords_score] = be.multiply(
|
931
|
-
scores[coords_score], (rotated_template[coords_template] <= 0.1)
|
932
|
-
)
|
933
|
-
return None
|
934
|
-
|
935
862
|
|
936
863
|
class PeakCallerScipy(PeakCaller):
|
937
864
|
"""
|
tme/analyzer/proxy.py
CHANGED