pytme 0.2.9.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.0__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.0.data/scripts/estimate_memory_usage.py +76 -0
- pytme-0.3.0.data/scripts/match_template.py +1106 -0
- {pytme-0.2.9.post1.data → pytme-0.3.0.data}/scripts/postprocess.py +320 -190
- {pytme-0.2.9.post1.data → pytme-0.3.0.data}/scripts/preprocess.py +21 -31
- {pytme-0.2.9.post1.data → pytme-0.3.0.data}/scripts/preprocessor_gui.py +85 -19
- pytme-0.3.0.data/scripts/pytme_runner.py +771 -0
- {pytme-0.2.9.post1.dist-info → pytme-0.3.0.dist-info}/METADATA +21 -20
- pytme-0.3.0.dist-info/RECORD +126 -0
- {pytme-0.2.9.post1.dist-info → pytme-0.3.0.dist-info}/entry_points.txt +2 -1
- pytme-0.3.0.dist-info/licenses/LICENSE +339 -0
- scripts/estimate_memory_usage.py +76 -0
- scripts/eval.py +93 -0
- scripts/extract_candidates.py +224 -0
- scripts/match_template.py +349 -378
- pytme-0.2.9.post1.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
- scripts/postprocess.py +320 -190
- scripts/preprocess.py +21 -31
- scripts/preprocessor_gui.py +85 -19
- scripts/pytme_runner.py +771 -0
- scripts/refine_matches.py +625 -0
- tests/preprocessing/test_frequency_filters.py +28 -14
- tests/test_analyzer.py +41 -36
- tests/test_backends.py +1 -0
- tests/test_matching_cli.py +109 -54
- tests/test_matching_data.py +5 -5
- tests/test_matching_exhaustive.py +1 -2
- tests/test_matching_optimization.py +4 -9
- tests/test_matching_utils.py +1 -1
- tests/test_orientations.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +26 -21
- tme/analyzer/aggregation.py +395 -222
- tme/analyzer/base.py +127 -0
- tme/analyzer/peaks.py +189 -204
- tme/analyzer/proxy.py +123 -0
- tme/backends/__init__.py +4 -3
- tme/backends/_cupy_utils.py +25 -24
- tme/backends/_jax_utils.py +20 -18
- tme/backends/cupy_backend.py +13 -26
- tme/backends/jax_backend.py +24 -23
- tme/backends/matching_backend.py +4 -3
- tme/backends/mlx_backend.py +4 -3
- tme/backends/npfftw_backend.py +34 -30
- tme/backends/pytorch_backend.py +18 -4
- tme/cli.py +126 -0
- tme/density.py +9 -7
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/__init__.py +3 -3
- tme/filters/_utils.py +36 -10
- tme/filters/bandpass.py +229 -188
- tme/filters/compose.py +5 -4
- tme/filters/ctf.py +516 -254
- tme/filters/reconstruction.py +91 -32
- tme/filters/wedge.py +196 -135
- tme/filters/whitening.py +37 -42
- tme/matching_data.py +28 -39
- tme/matching_exhaustive.py +31 -27
- tme/matching_optimization.py +5 -4
- tme/matching_scores.py +25 -15
- tme/matching_utils.py +193 -27
- tme/memory.py +4 -3
- tme/orientations.py +22 -9
- tme/parser.py +114 -33
- tme/preprocessor.py +6 -5
- tme/rotations.py +10 -7
- tme/structure.py +4 -3
- pytme-0.2.9.post1.data/scripts/estimate_ram_usage.py +0 -97
- pytme-0.2.9.post1.dist-info/RECORD +0 -119
- pytme-0.2.9.post1.dist-info/licenses/LICENSE +0 -153
- scripts/estimate_ram_usage.py +0 -97
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- {pytme-0.2.9.post1.dist-info → pytme-0.3.0.dist-info}/WHEEL +0 -0
- {pytme-0.2.9.post1.dist-info → pytme-0.3.0.dist-info}/top_level.txt +0 -0
tme/analyzer/aggregation.py
CHANGED
@@ -1,34 +1,34 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
Implements classes to analyze outputs from exhaustive template matching.
|
2
3
|
|
3
|
-
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
5
|
|
5
|
-
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
7
|
"""
|
7
8
|
|
8
|
-
from
|
9
|
-
from multiprocessing import Manager
|
10
|
-
from typing import Tuple, List, Dict, Generator
|
9
|
+
from typing import Tuple, List, Dict
|
11
10
|
|
12
11
|
import numpy as np
|
13
12
|
|
13
|
+
from .base import AbstractAnalyzer
|
14
14
|
from ..types import BackendArray
|
15
15
|
from ._utils import cart_to_score
|
16
16
|
from ..backends import backend as be
|
17
17
|
from ..matching_utils import (
|
18
18
|
create_mask,
|
19
19
|
array_to_memmap,
|
20
|
-
generate_tempfile_name,
|
21
20
|
apply_convolution_mode,
|
21
|
+
generate_tempfile_name,
|
22
22
|
)
|
23
23
|
|
24
|
-
|
25
24
|
__all__ = [
|
26
25
|
"MaxScoreOverRotations",
|
26
|
+
"MaxScoreOverRotationsConstrained",
|
27
27
|
"MaxScoreOverTranslations",
|
28
28
|
]
|
29
29
|
|
30
30
|
|
31
|
-
class MaxScoreOverRotations:
|
31
|
+
class MaxScoreOverRotations(AbstractAnalyzer):
|
32
32
|
"""
|
33
33
|
Determine the rotation maximizing the score over all possible translations.
|
34
34
|
|
@@ -36,10 +36,6 @@ class MaxScoreOverRotations:
|
|
36
36
|
----------
|
37
37
|
shape : tuple of int
|
38
38
|
Shape of array passed to :py:meth:`MaxScoreOverRotations.__call__`.
|
39
|
-
scores : BackendArray, optional
|
40
|
-
Array mapping translations to scores.
|
41
|
-
rotations : BackendArray, optional
|
42
|
-
Array mapping translations to rotation indices.
|
43
39
|
offset : BackendArray, optional
|
44
40
|
Coordinate origin considered during merging, zero by default.
|
45
41
|
score_threshold : float, optional
|
@@ -50,58 +46,47 @@ class MaxScoreOverRotations:
|
|
50
46
|
Memmap internal arrays, False by default.
|
51
47
|
thread_safe: bool, optional
|
52
48
|
Allow class to be modified by multiple processes, True by default.
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
scores : BackendArray
|
59
|
-
Mapping of translations to scores.
|
60
|
-
rotations : BackendArray
|
61
|
-
Mmapping of translations to rotation indices.
|
62
|
-
rotation_mapping : Dict
|
63
|
-
Mapping of rotations to rotation indices.
|
64
|
-
offset : BackendArray, optional
|
65
|
-
Coordinate origin considered during merging, zero by default
|
49
|
+
inversion_mapping : bool, optional
|
50
|
+
Do not use rotation matrix bytestrings for intermediate data handling.
|
51
|
+
This is useful for GPU backend where analyzers are not shared across
|
52
|
+
devices and every rotation is only observed once. It is generally
|
53
|
+
safe to deactivate inversion mapping, but at a cost of performance.
|
66
54
|
|
67
55
|
Examples
|
68
56
|
--------
|
69
57
|
The following achieves the minimal definition of a :py:class:`MaxScoreOverRotations`
|
70
58
|
instance
|
71
59
|
|
60
|
+
>>> import numpy as np
|
72
61
|
>>> from tme.analyzer import MaxScoreOverRotations
|
73
|
-
>>> analyzer = MaxScoreOverRotations(shape
|
62
|
+
>>> analyzer = MaxScoreOverRotations(shape=(50, 50))
|
74
63
|
|
75
64
|
The following simulates a template matching run by creating random data for a range
|
76
65
|
of rotations and sending it to ``analyzer`` via its __call__ method
|
77
66
|
|
67
|
+
>>> state = analyzer.init_state()
|
78
68
|
>>> for rotation_number in range(10):
|
79
69
|
>>> scores = np.random.rand(50,50)
|
80
70
|
>>> rotation = np.random.rand(scores.ndim, scores.ndim)
|
81
|
-
>>> analyzer(scores
|
71
|
+
>>> state = analyzer(state, scores=scores, rotation_matrix=rotation)
|
82
72
|
|
83
|
-
The aggregated scores can be extracted by invoking the
|
73
|
+
The aggregated scores can be extracted by invoking the result method of
|
84
74
|
``analyzer``
|
85
75
|
|
86
|
-
>>> results =
|
76
|
+
>>> results = analyzer.result(state)
|
87
77
|
|
88
78
|
The ``results`` tuple contains (1) the maximum scores for each translation,
|
89
79
|
(2) an offset which is relevant when merging results from split template matching
|
90
80
|
using :py:meth:`MaxScoreOverRotations.merge`, (3) the rotation used to obtain a
|
91
|
-
score for a given translation, (4) a dictionary mapping
|
92
|
-
|
81
|
+
score for a given translation, (4) a dictionary mapping indices used in (2) to
|
82
|
+
rotation matrices (2).
|
93
83
|
|
94
84
|
We can extract the ``optimal_score``, ``optimal_translation`` and ``optimal_rotation``
|
95
85
|
as follows
|
96
86
|
|
97
87
|
>>> optimal_score = results[0].max()
|
98
88
|
>>> optimal_translation = np.where(results[0] == results[0].max())
|
99
|
-
>>>
|
100
|
-
>>> for key, value in results[3].items():
|
101
|
-
>>> if value != optimal_rotation_index:
|
102
|
-
>>> continue
|
103
|
-
>>> optimal_rotation = np.frombuffer(key, rotation.dtype)
|
104
|
-
>>> optimal_rotation = optimal_rotation.reshape(scores.ndim, scores.ndim)
|
89
|
+
>>> optimal_rotation = results[2][optimal_translation]
|
105
90
|
|
106
91
|
The outlined procedure is a trivial method to identify high scoring peaks.
|
107
92
|
Alternatively, :py:class:`PeakCaller` offers a range of more elaborate approaches
|
@@ -111,156 +96,221 @@ class MaxScoreOverRotations:
|
|
111
96
|
def __init__(
|
112
97
|
self,
|
113
98
|
shape: Tuple[int],
|
114
|
-
scores: BackendArray = None,
|
115
|
-
rotations: BackendArray = None,
|
116
99
|
offset: BackendArray = None,
|
117
100
|
score_threshold: float = 0,
|
118
101
|
shm_handler: object = None,
|
119
102
|
use_memmap: bool = False,
|
120
|
-
|
121
|
-
|
103
|
+
inversion_mapping: bool = False,
|
104
|
+
jax_mode: bool = False,
|
122
105
|
**kwargs,
|
123
106
|
):
|
107
|
+
self._use_memmap = use_memmap
|
108
|
+
self._score_threshold = score_threshold
|
124
109
|
self._shape = tuple(int(x) for x in shape)
|
110
|
+
self._inversion_mapping = inversion_mapping
|
125
111
|
|
126
|
-
self.
|
127
|
-
if self.
|
128
|
-
self.
|
129
|
-
shape=self._shape, dtype=be._float_dtype, fill_value=score_threshold
|
130
|
-
)
|
131
|
-
self.rotations = rotations
|
132
|
-
if self.rotations is None:
|
133
|
-
self.rotations = be.full(self._shape, dtype=be._int_dtype, fill_value=-1)
|
134
|
-
|
135
|
-
self.scores = be.to_sharedarr(self.scores, shm_handler)
|
136
|
-
self.rotations = be.to_sharedarr(self.rotations, shm_handler)
|
112
|
+
self._jax_mode = jax_mode
|
113
|
+
if self._jax_mode:
|
114
|
+
self._inversion_mapping = False
|
137
115
|
|
138
116
|
if offset is None:
|
139
117
|
offset = be.zeros(len(self._shape), be._int_dtype)
|
140
|
-
self.
|
118
|
+
self._offset = be.astype(be.to_backend_array(offset), int)
|
141
119
|
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
120
|
+
@property
|
121
|
+
def shareable(self):
|
122
|
+
return True
|
123
|
+
|
124
|
+
def init_state(self):
|
125
|
+
"""
|
126
|
+
Initialize the analysis state.
|
127
|
+
|
128
|
+
Returns
|
129
|
+
-------
|
130
|
+
tuple
|
131
|
+
Initial state tuple containing (scores, rotations, rotation_mapping) where:
|
132
|
+
- scores : BackendArray of shape `self._shape` filled with `score_threshold`.
|
133
|
+
- rotations : BackendArray of shape `self._shape` filled with -1.
|
134
|
+
- rotation_mapping : dict, empty mapping from rotation bytes to indices.
|
135
|
+
"""
|
136
|
+
scores = be.full(
|
137
|
+
shape=self._shape, dtype=be._float_dtype, fill_value=self._score_threshold
|
138
|
+
)
|
139
|
+
rotations = be.full(self._shape, dtype=be._int_dtype, fill_value=-1)
|
140
|
+
return scores, rotations, {}
|
141
|
+
|
142
|
+
def __call__(
|
143
|
+
self,
|
144
|
+
state: Tuple,
|
145
|
+
scores: BackendArray,
|
146
|
+
rotation_matrix: BackendArray,
|
147
|
+
**kwargs,
|
148
|
+
) -> Tuple:
|
149
|
+
"""
|
150
|
+
Update the parameter store.
|
151
|
+
|
152
|
+
Parameters
|
153
|
+
----------
|
154
|
+
state : tuple
|
155
|
+
Current state tuple (scores, rotations, rotation_mapping) where:
|
156
|
+
- scores : BackendArray, current maximum scores.
|
157
|
+
- rotations : BackendArray, current rotation indices.
|
158
|
+
- rotation_mapping : dict, mapping from rotation bytes to indices.
|
159
|
+
scores : BackendArray
|
160
|
+
Array of new scores to update analyzer with.
|
161
|
+
rotation_matrix : BackendArray
|
162
|
+
Square matrix used to obtain the current rotation.
|
163
|
+
Returns
|
164
|
+
-------
|
165
|
+
tuple
|
166
|
+
Updated state tuple (scores, rotations, rotation_mapping).
|
167
|
+
"""
|
168
|
+
# be.tobytes behaviour caused overhead for certain GPU/CUDA combinations
|
169
|
+
# If the analyzer is not shared and each rotation is unique, we can
|
170
|
+
# use index to rotation mapping and invert prior to merging.
|
171
|
+
prev_scores, rotations, rotation_mapping = state
|
172
|
+
|
173
|
+
rotation_index = len(rotation_mapping)
|
174
|
+
rotation_matrix = be.astype(rotation_matrix, be._float_dtype)
|
175
|
+
if self._inversion_mapping:
|
176
|
+
rotation_mapping[rotation_index] = rotation_matrix
|
177
|
+
elif self._jax_mode:
|
178
|
+
rotation_index = kwargs.get("rotation_index", 0)
|
179
|
+
else:
|
180
|
+
rotation = be.tobytes(rotation_matrix)
|
181
|
+
rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
|
182
|
+
|
183
|
+
scores, rotations = be.max_score_over_rotations(
|
184
|
+
scores=scores,
|
185
|
+
max_scores=prev_scores,
|
186
|
+
rotations=rotations,
|
187
|
+
rotation_index=rotation_index,
|
188
|
+
)
|
189
|
+
return scores, rotations, rotation_mapping
|
147
190
|
|
148
|
-
|
191
|
+
@staticmethod
|
192
|
+
def _invert_rmap(rotation_mapping: dict) -> dict:
|
193
|
+
"""
|
194
|
+
Invert dictionary from rotation matrix bytestrings mapping to rotation
|
195
|
+
indices ro rotation indices mapping to rotation matrices.
|
196
|
+
"""
|
197
|
+
new_map, ndim = {}, None
|
198
|
+
for k, v in rotation_mapping.items():
|
199
|
+
nbytes = be.datatype_bytes(be._float_dtype)
|
200
|
+
dtype = np.float32 if nbytes == 4 else np.float16
|
201
|
+
rmat = np.frombuffer(k, dtype=dtype)
|
202
|
+
if ndim is None:
|
203
|
+
ndim = int(np.sqrt(rmat.size))
|
204
|
+
new_map[v] = rmat.reshape(ndim, ndim)
|
205
|
+
return new_map
|
206
|
+
|
207
|
+
def result(
|
149
208
|
self,
|
150
|
-
|
151
|
-
|
152
|
-
|
209
|
+
state,
|
210
|
+
targetshape: Tuple[int] = None,
|
211
|
+
templateshape: Tuple[int] = None,
|
212
|
+
convolution_shape: Tuple[int] = None,
|
153
213
|
fourier_shift: Tuple[int] = None,
|
154
214
|
convolution_mode: str = None,
|
155
|
-
shm_handler=None,
|
156
215
|
**kwargs,
|
157
|
-
) ->
|
158
|
-
"""
|
159
|
-
|
160
|
-
|
216
|
+
) -> Tuple:
|
217
|
+
"""
|
218
|
+
Finalize the analysis result with optional postprocessing.
|
219
|
+
|
220
|
+
Parameters
|
221
|
+
----------
|
222
|
+
state : tuple
|
223
|
+
Current state tuple (scores, rotations, rotation_mapping) where:
|
224
|
+
- scores : BackendArray, current maximum scores.
|
225
|
+
- rotations : BackendArray, current rotation indices.
|
226
|
+
- rotation_mapping : dict, mapping from rotation indices to matrices.
|
227
|
+
targetshape : Tuple[int], optional
|
228
|
+
Shape of the target for convolution mode correction.
|
229
|
+
templateshape : Tuple[int], optional
|
230
|
+
Shape of the template for convolution mode correction.
|
231
|
+
convolution_shape : Tuple[int], optional
|
232
|
+
Shape used for convolution.
|
233
|
+
fourier_shift : Tuple[int], optional.
|
234
|
+
Shift to apply for Fourier correction.
|
235
|
+
convolution_mode : str, optional
|
236
|
+
Convolution mode for padding correction.
|
237
|
+
**kwargs
|
238
|
+
Additional keyword arguments.
|
239
|
+
|
240
|
+
Returns
|
241
|
+
-------
|
242
|
+
tuple
|
243
|
+
Final result tuple (scores, offset, rotations, rotation_mapping).
|
244
|
+
"""
|
245
|
+
scores, rotations, rotation_mapping = state
|
246
|
+
|
247
|
+
# Apply postprocessing if parameters are provided
|
161
248
|
if fourier_shift is not None:
|
162
249
|
axis = tuple(i for i in range(len(fourier_shift)))
|
163
250
|
scores = be.roll(scores, shift=fourier_shift, axis=axis)
|
164
251
|
rotations = be.roll(rotations, shift=fourier_shift, axis=axis)
|
165
252
|
|
166
|
-
convargs = {
|
167
|
-
"s1": targetshape,
|
168
|
-
"s2": templateshape,
|
169
|
-
"convolution_mode": convolution_mode,
|
170
|
-
"convolution_shape": convolution_shape,
|
171
|
-
}
|
172
253
|
if convolution_mode is not None:
|
254
|
+
convargs = {
|
255
|
+
"s1": targetshape,
|
256
|
+
"s2": templateshape,
|
257
|
+
"convolution_mode": convolution_mode,
|
258
|
+
"convolution_shape": convolution_shape,
|
259
|
+
}
|
173
260
|
scores = apply_convolution_mode(scores, **convargs)
|
174
261
|
rotations = apply_convolution_mode(rotations, **convargs)
|
175
262
|
|
176
|
-
self._shape, self.scores, self.rotations = scores.shape, scores, rotations
|
177
|
-
if shm_handler is not None:
|
178
|
-
self.scores = be.to_sharedarr(scores, shm_handler)
|
179
|
-
self.rotations = be.to_sharedarr(rotations, shm_handler)
|
180
|
-
return self
|
181
|
-
|
182
|
-
def __iter__(self) -> Generator:
|
183
|
-
scores = be.from_sharedarr(self.scores)
|
184
|
-
rotations = be.from_sharedarr(self.rotations)
|
185
|
-
|
186
263
|
scores = be.to_numpy_array(scores)
|
187
264
|
rotations = be.to_numpy_array(rotations)
|
188
265
|
if self._use_memmap:
|
189
266
|
scores = array_to_memmap(scores)
|
190
267
|
rotations = array_to_memmap(rotations)
|
191
|
-
else:
|
192
|
-
if type(self.scores) is not type(scores):
|
193
|
-
# Copy to avoid invalidation by shared memory handler
|
194
|
-
scores, rotations = scores.copy(), rotations.copy()
|
195
268
|
|
196
269
|
if self._inversion_mapping:
|
197
|
-
|
198
|
-
be.tobytes(v): k for k, v in self.rotation_mapping.items()
|
199
|
-
}
|
270
|
+
rotation_mapping = {be.tobytes(v): k for k, v in rotation_mapping.items()}
|
200
271
|
|
201
|
-
|
272
|
+
return (
|
202
273
|
scores,
|
203
|
-
be.to_numpy_array(self.
|
274
|
+
be.to_numpy_array(self._offset),
|
204
275
|
rotations,
|
205
|
-
|
276
|
+
self._invert_rmap(rotation_mapping),
|
206
277
|
)
|
207
|
-
yield from param_store
|
208
278
|
|
209
|
-
def
|
279
|
+
def _harmonize_states(states: List[Tuple]):
|
210
280
|
"""
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
----------
|
215
|
-
scores : BackendArray
|
216
|
-
Array of scores.
|
217
|
-
rotation_matrix : BackendArray
|
218
|
-
Square matrix describing the current rotation.
|
281
|
+
Create consistent reference frame for merging different analyzer
|
282
|
+
instances, w.r.t. to rotations and output shape from different
|
283
|
+
splits of the target.
|
219
284
|
"""
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
rotation_index = len(self.rotation_mapping)
|
225
|
-
if self._inversion_mapping:
|
226
|
-
self.rotation_mapping[rotation_index] = rotation_matrix
|
227
|
-
else:
|
228
|
-
rotation = be.tobytes(rotation_matrix)
|
229
|
-
rotation_index = self.rotation_mapping.setdefault(
|
230
|
-
rotation, rotation_index
|
231
|
-
)
|
232
|
-
self.scores, self.rotations = be.max_score_over_rotations(
|
233
|
-
scores=scores,
|
234
|
-
max_scores=self.scores,
|
235
|
-
rotations=self.rotations,
|
236
|
-
rotation_index=rotation_index,
|
237
|
-
)
|
238
|
-
return None
|
285
|
+
new_rotation_mapping, out_shape = {}, None
|
286
|
+
for i in range(len(states)):
|
287
|
+
if states[i] is None:
|
288
|
+
continue
|
239
289
|
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
290
|
+
scores, offset, rotations, rotation_mapping = states[i]
|
291
|
+
if out_shape is None:
|
292
|
+
out_shape = np.zeros(scores.ndim, int)
|
293
|
+
out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
|
294
|
+
|
295
|
+
new_param = {}
|
296
|
+
for key, value in rotation_mapping.items():
|
297
|
+
rotation_bytes = be.tobytes(value)
|
298
|
+
new_param[rotation_bytes] = key
|
299
|
+
if rotation_bytes not in new_rotation_mapping:
|
300
|
+
new_rotation_mapping[rotation_bytes] = len(new_rotation_mapping)
|
301
|
+
states[i] = (scores, offset, rotations, new_param)
|
302
|
+
out_shape = tuple(int(x) for x in out_shape)
|
303
|
+
return new_rotation_mapping, out_shape, states
|
254
304
|
|
255
305
|
@classmethod
|
256
|
-
def merge(cls,
|
306
|
+
def merge(cls, results: List[Tuple], **kwargs) -> Tuple:
|
257
307
|
"""
|
258
308
|
Merge multiple instances of the current class.
|
259
309
|
|
260
310
|
Parameters
|
261
311
|
----------
|
262
|
-
|
263
|
-
List of instance's internal state created by applying `
|
312
|
+
results : list of tuple
|
313
|
+
List of instance's internal state created by applying `result`.
|
264
314
|
**kwargs : dict, optional
|
265
315
|
Optional keyword arguments.
|
266
316
|
|
@@ -276,8 +326,8 @@ class MaxScoreOverRotations:
|
|
276
326
|
Mapping between rotations and rotation indices.
|
277
327
|
"""
|
278
328
|
use_memmap = kwargs.get("use_memmap", False)
|
279
|
-
if len(
|
280
|
-
ret =
|
329
|
+
if len(results) == 1:
|
330
|
+
ret = results[0]
|
281
331
|
if use_memmap:
|
282
332
|
scores, offset, rotations, rotation_mapping = ret
|
283
333
|
scores = array_to_memmap(scores)
|
@@ -287,25 +337,12 @@ class MaxScoreOverRotations:
|
|
287
337
|
return ret
|
288
338
|
|
289
339
|
# Determine output array shape and create consistent rotation map
|
290
|
-
|
291
|
-
for i in range(len(param_stores)):
|
292
|
-
if param_stores[i] is None:
|
293
|
-
continue
|
294
|
-
|
295
|
-
scores, offset, rotations, rotation_mapping = param_stores[i]
|
296
|
-
if out_shape is None:
|
297
|
-
out_shape = np.zeros(scores.ndim, int)
|
298
|
-
scores_dtype, rotations_dtype = scores.dtype, rotations.dtype
|
299
|
-
out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
|
300
|
-
|
301
|
-
for key, value in rotation_mapping.items():
|
302
|
-
if key not in new_rotation_mapping:
|
303
|
-
new_rotation_mapping[key] = len(new_rotation_mapping)
|
304
|
-
|
340
|
+
master_rotation_mapping, out_shape, results = cls._harmonize_states(results)
|
305
341
|
if out_shape is None:
|
306
342
|
return None
|
307
343
|
|
308
|
-
|
344
|
+
scores_dtype = results[0][0].dtype
|
345
|
+
rotations_dtype = results[0][2].dtype
|
309
346
|
if use_memmap:
|
310
347
|
scores_out_filename = generate_tempfile_name()
|
311
348
|
rotations_out_filename = generate_tempfile_name()
|
@@ -331,8 +368,8 @@ class MaxScoreOverRotations:
|
|
331
368
|
)
|
332
369
|
rotations_out = np.full(out_shape, fill_value=-1, dtype=rotations_dtype)
|
333
370
|
|
334
|
-
for i in range(len(
|
335
|
-
if
|
371
|
+
for i in range(len(results)):
|
372
|
+
if results[i] is None:
|
336
373
|
continue
|
337
374
|
|
338
375
|
if use_memmap:
|
@@ -348,7 +385,7 @@ class MaxScoreOverRotations:
|
|
348
385
|
shape=out_shape,
|
349
386
|
dtype=rotations_dtype,
|
350
387
|
)
|
351
|
-
scores, offset, rotations, rotation_mapping =
|
388
|
+
scores, offset, rotations, rotation_mapping = results[i]
|
352
389
|
stops = np.add(offset, scores.shape).astype(int)
|
353
390
|
indices = tuple(slice(*pos) for pos in zip(offset, stops))
|
354
391
|
|
@@ -359,7 +396,7 @@ class MaxScoreOverRotations:
|
|
359
396
|
len(rotation_mapping) + 1, dtype=rotations_out.dtype
|
360
397
|
)
|
361
398
|
for key, value in rotation_mapping.items():
|
362
|
-
lookup_table[value] =
|
399
|
+
lookup_table[value] = master_rotation_mapping[key]
|
363
400
|
|
364
401
|
updated_rotations = rotations[indices_update]
|
365
402
|
if len(updated_rotations):
|
@@ -372,7 +409,7 @@ class MaxScoreOverRotations:
|
|
372
409
|
rotations_out.flush()
|
373
410
|
scores_out, rotations_out = None, None
|
374
411
|
|
375
|
-
|
412
|
+
results[i] = None
|
376
413
|
scores, rotations = None, None
|
377
414
|
|
378
415
|
if use_memmap:
|
@@ -390,13 +427,165 @@ class MaxScoreOverRotations:
|
|
390
427
|
scores_out,
|
391
428
|
np.zeros(scores_out.ndim, dtype=int),
|
392
429
|
rotations_out,
|
393
|
-
|
430
|
+
cls._invert_rmap(master_rotation_mapping),
|
394
431
|
)
|
395
432
|
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
433
|
+
|
434
|
+
class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
|
435
|
+
"""
|
436
|
+
Implements constrained template matching using rejection sampling.
|
437
|
+
|
438
|
+
Parameters
|
439
|
+
----------
|
440
|
+
cone_angle : float
|
441
|
+
Maximum accepted rotational deviation in degrees.
|
442
|
+
positions : BackendArray
|
443
|
+
Array of shape (n, d) with n seed point translations.
|
444
|
+
positions : BackendArray
|
445
|
+
Array of shape (n, d, d) with n seed point rotation matrices.
|
446
|
+
reference : BackendArray
|
447
|
+
Reference orientation of the template, wlog defaults to (0,0,1).
|
448
|
+
acceptance_radius : int or tuple of ints
|
449
|
+
Translational acceptance radius around seed point in voxels.
|
450
|
+
**kwargs : dict, optional
|
451
|
+
Keyword aguments passed to the constructor of :py:class:`MaxScoreOverRotations`.
|
452
|
+
"""
|
453
|
+
|
454
|
+
def __init__(
|
455
|
+
self,
|
456
|
+
cone_angle: float,
|
457
|
+
positions: BackendArray,
|
458
|
+
rotations: BackendArray,
|
459
|
+
reference: BackendArray = (0, 0, 1),
|
460
|
+
acceptance_radius: int = 10,
|
461
|
+
**kwargs,
|
462
|
+
):
|
463
|
+
MaxScoreOverRotations.__init__(self, **kwargs)
|
464
|
+
|
465
|
+
if not isinstance(acceptance_radius, (int, Tuple)):
|
466
|
+
raise ValueError("acceptance_radius needs to be of type int or tuple.")
|
467
|
+
|
468
|
+
if isinstance(acceptance_radius, int):
|
469
|
+
acceptance_radius = (
|
470
|
+
acceptance_radius,
|
471
|
+
acceptance_radius,
|
472
|
+
acceptance_radius,
|
473
|
+
)
|
474
|
+
acceptance_radius = tuple(int(x) for x in acceptance_radius)
|
475
|
+
|
476
|
+
self._cone_angle = float(np.radians(cone_angle))
|
477
|
+
self._cone_cutoff = float(np.tan(self._cone_angle))
|
478
|
+
self._reference = be.astype(
|
479
|
+
be.reshape(be.to_backend_array(reference), (-1,)), be._float_dtype
|
480
|
+
)
|
481
|
+
positions = be.astype(be.to_backend_array(positions), be._int_dtype)
|
482
|
+
|
483
|
+
ndim = positions.shape[1]
|
484
|
+
rotate_mask = len(set(acceptance_radius)) != 1
|
485
|
+
extend = max(acceptance_radius)
|
486
|
+
mask = create_mask(
|
487
|
+
mask_type="ellipse",
|
488
|
+
radius=acceptance_radius,
|
489
|
+
shape=tuple(2 * extend + 1 for _ in range(ndim)),
|
490
|
+
center=tuple(extend for _ in range(ndim)),
|
491
|
+
)
|
492
|
+
self._score_mask = be.astype(be.to_backend_array(mask), be._float_dtype)
|
493
|
+
|
494
|
+
# Map position from real space to shifted score space
|
495
|
+
lower_limit = be.to_backend_array(self._offset)
|
496
|
+
positions = be.subtract(positions, lower_limit)
|
497
|
+
positions, valid_positions = cart_to_score(
|
498
|
+
positions=positions,
|
499
|
+
fast_shape=kwargs.get("fast_shape", None),
|
500
|
+
targetshape=kwargs.get("targetshape", None),
|
501
|
+
templateshape=kwargs.get("templateshape", None),
|
502
|
+
fourier_shift=kwargs.get("fourier_shift", None),
|
503
|
+
convolution_mode=kwargs.get("convolution_mode", None),
|
504
|
+
convolution_shape=kwargs.get("convolution_shape", None),
|
505
|
+
)
|
506
|
+
|
507
|
+
self._positions = positions[valid_positions]
|
508
|
+
rotations = be.to_backend_array(rotations)[valid_positions]
|
509
|
+
ex = be.astype(be.to_backend_array((1, 0, 0)), be._float_dtype)
|
510
|
+
ey = be.astype(be.to_backend_array((0, 1, 0)), be._float_dtype)
|
511
|
+
ez = be.astype(be.to_backend_array((0, 0, 1)), be._float_dtype)
|
512
|
+
|
513
|
+
self._normals_x = (rotations @ ex[..., None])[..., 0]
|
514
|
+
self._normals_y = (rotations @ ey[..., None])[..., 0]
|
515
|
+
self._normals_z = (rotations @ ez[..., None])[..., 0]
|
516
|
+
|
517
|
+
# Periodic wrapping could be avoided by padding the target
|
518
|
+
shape = be.to_backend_array(self._shape)
|
519
|
+
starts = be.subtract(self._positions, extend)
|
520
|
+
ret, (n, d), mshape = [], self._positions.shape, self._score_mask.shape
|
521
|
+
if starts.shape[0] > 0:
|
522
|
+
for i in range(d):
|
523
|
+
indices = starts[:, slice(i, i + 1)] + be.arange(mshape[i])[None]
|
524
|
+
indices = be.mod(indices, shape[i], out=indices)
|
525
|
+
indices_shape = (n, *tuple(1 if k != i else -1 for k in range(d)))
|
526
|
+
ret.append(be.reshape(indices, indices_shape))
|
527
|
+
|
528
|
+
self._index_grid = tuple(ret)
|
529
|
+
self._mask_shape = tuple(1 if i != 0 else -1 for i in range(1 + ndim))
|
530
|
+
|
531
|
+
if rotate_mask:
|
532
|
+
self._score_mask = be.zeros(
|
533
|
+
(rotations.shape[0], *self._score_mask.shape), dtype=be._float_dtype
|
534
|
+
)
|
535
|
+
for i in range(rotations.shape[0]):
|
536
|
+
mask = create_mask(
|
537
|
+
mask_type="ellipse",
|
538
|
+
radius=acceptance_radius,
|
539
|
+
shape=tuple(2 * extend + 1 for _ in range(ndim)),
|
540
|
+
center=tuple(extend for _ in range(ndim)),
|
541
|
+
orientation=be.to_numpy_array(rotations[i]),
|
542
|
+
)
|
543
|
+
self._score_mask[i] = be.astype(
|
544
|
+
be.to_backend_array(mask), be._float_dtype
|
545
|
+
)
|
546
|
+
|
547
|
+
def __call__(
|
548
|
+
self, state: Tuple, scores: BackendArray, rotation_matrix: BackendArray
|
549
|
+
) -> Tuple:
|
550
|
+
mask = self._get_constraint(rotation_matrix)
|
551
|
+
mask = self._get_score_mask(mask=mask, scores=scores)
|
552
|
+
|
553
|
+
scores = be.multiply(scores, mask, out=scores)
|
554
|
+
return super().__call__(state, scores=scores, rotation_matrix=rotation_matrix)
|
555
|
+
|
556
|
+
def _get_constraint(self, rotation_matrix: BackendArray) -> BackendArray:
|
557
|
+
"""
|
558
|
+
Determine whether the angle between projection of reference w.r.t to
|
559
|
+
a given rotation matrix and a set of rotations fall within the set
|
560
|
+
cone_angle cutoff.
|
561
|
+
|
562
|
+
Parameters
|
563
|
+
----------
|
564
|
+
rotation_matrix : BackendArray
|
565
|
+
Rotation matrix with shape (d,d).
|
566
|
+
|
567
|
+
Returns
|
568
|
+
-------
|
569
|
+
BackerndArray
|
570
|
+
Boolean mask of shape (n, )
|
571
|
+
"""
|
572
|
+
template_rot = rotation_matrix @ self._reference
|
573
|
+
|
574
|
+
x = be.sum(be.multiply(self._normals_x, template_rot), axis=1)
|
575
|
+
y = be.sum(be.multiply(self._normals_y, template_rot), axis=1)
|
576
|
+
z = be.sum(be.multiply(self._normals_z, template_rot), axis=1)
|
577
|
+
|
578
|
+
return be.sqrt(x**2 + y**2) <= (z * self._cone_cutoff)
|
579
|
+
|
580
|
+
def _get_score_mask(self, mask: BackendArray, scores: BackendArray, **kwargs):
|
581
|
+
score_mask = be.zeros(scores.shape, scores.dtype)
|
582
|
+
|
583
|
+
if be.sum(mask) == 0:
|
584
|
+
return score_mask
|
585
|
+
mask = be.reshape(mask, self._mask_shape)
|
586
|
+
|
587
|
+
score_mask = be.addat(score_mask, self._index_grid, self._score_mask * mask)
|
588
|
+
return score_mask > 0
|
400
589
|
|
401
590
|
|
402
591
|
class MaxScoreOverTranslations(MaxScoreOverRotations):
|
@@ -437,43 +626,40 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
437
626
|
super().__init__(
|
438
627
|
shape=shape_reduced, shm_handler=shm_handler, offset=offset, **kwargs
|
439
628
|
)
|
440
|
-
|
441
|
-
self.rotations = be.full(1, dtype=be._int_dtype, fill_value=-1)
|
442
|
-
self.rotations = be.to_sharedarr(self.rotations, shm_handler)
|
443
629
|
self._aggregate_axis = aggregate_axis
|
444
630
|
|
445
|
-
def
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
rotation = be.tobytes(rotation_matrix)
|
452
|
-
rotation_index = self.rotation_mapping.setdefault(
|
453
|
-
rotation, rotation_index
|
454
|
-
)
|
455
|
-
max_score = be.max(scores, axis=self._aggregate_axis)
|
456
|
-
self.scores[rotation_index] = max_score
|
457
|
-
return None
|
631
|
+
def init_state(self):
|
632
|
+
scores = be.full(
|
633
|
+
shape=self._shape, dtype=be._float_dtype, fill_value=self._score_threshold
|
634
|
+
)
|
635
|
+
rotations = be.full(1, dtype=be._int_dtype, fill_value=-1)
|
636
|
+
return scores, rotations, {}
|
458
637
|
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
638
|
+
def __call__(
|
639
|
+
self, state, scores: BackendArray, rotation_matrix: BackendArray
|
640
|
+
) -> Tuple:
|
641
|
+
prev_scores, rotations, rotation_mapping = state
|
642
|
+
|
643
|
+
rotation_index = len(rotation_mapping)
|
644
|
+
if self._inversion_mapping:
|
645
|
+
rotation_mapping[rotation_index] = rotation_matrix
|
646
|
+
else:
|
647
|
+
rotation = be.tobytes(rotation_matrix)
|
648
|
+
rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
|
649
|
+
max_score = be.max(scores, axis=self._aggregate_axis)
|
650
|
+
|
651
|
+
update = prev_scores[rotation_index]
|
652
|
+
update = be.maximum(max_score, update, out=update)
|
653
|
+
return prev_scores, rotations, rotation_mapping
|
468
654
|
|
469
655
|
@classmethod
|
470
|
-
def merge(cls,
|
656
|
+
def merge(cls, states: List[Tuple], **kwargs) -> Tuple:
|
471
657
|
"""
|
472
658
|
Merge multiple instances of the current class.
|
473
659
|
|
474
660
|
Parameters
|
475
661
|
----------
|
476
|
-
|
662
|
+
states : list of tuple
|
477
663
|
List of instance's internal state created by applying `tuple(instance)`.
|
478
664
|
**kwargs : dict, optional
|
479
665
|
Optional keyword arguments.
|
@@ -489,31 +675,18 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
489
675
|
Dict
|
490
676
|
Mapping between rotations and rotation indices.
|
491
677
|
"""
|
492
|
-
if len(
|
493
|
-
return
|
678
|
+
if len(states) == 1:
|
679
|
+
return states[0]
|
494
680
|
|
495
681
|
# Determine output array shape and create consistent rotation map
|
496
|
-
|
497
|
-
for i in range(len(param_stores)):
|
498
|
-
if param_stores[i] is None:
|
499
|
-
continue
|
500
|
-
|
501
|
-
scores, offset, rotations, rotation_mapping = param_stores[i]
|
502
|
-
if out_shape is None:
|
503
|
-
out_shape = np.zeros(scores.ndim, int)
|
504
|
-
scores_dtype, rotations_out = scores.dtype, rotations
|
505
|
-
out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
|
506
|
-
|
507
|
-
for key, value in rotation_mapping.items():
|
508
|
-
if key not in new_rotation_mapping:
|
509
|
-
new_rotation_mapping[key] = len(new_rotation_mapping)
|
510
|
-
|
682
|
+
states, master_rotation_mapping, out_shape = cls._harmonize_states(states)
|
511
683
|
if out_shape is None:
|
512
684
|
return None
|
513
685
|
|
514
|
-
out_shape[0] = len(
|
686
|
+
out_shape[0] = len(master_rotation_mapping)
|
515
687
|
out_shape = tuple(int(x) for x in out_shape)
|
516
688
|
|
689
|
+
scores_dtype = states[0][0].dtype
|
517
690
|
use_memmap = kwargs.get("use_memmap", False)
|
518
691
|
if use_memmap:
|
519
692
|
scores_out_filename = generate_tempfile_name()
|
@@ -529,8 +702,8 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
529
702
|
dtype=scores_dtype,
|
530
703
|
)
|
531
704
|
|
532
|
-
for i in range(len(
|
533
|
-
if
|
705
|
+
for i in range(len(states)):
|
706
|
+
if states[i] is None:
|
534
707
|
continue
|
535
708
|
|
536
709
|
if use_memmap:
|
@@ -540,11 +713,11 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
540
713
|
shape=out_shape,
|
541
714
|
dtype=scores_dtype,
|
542
715
|
)
|
543
|
-
scores, offset, rotations, rotation_mapping =
|
716
|
+
scores, offset, rotations, rotation_mapping = states[i]
|
544
717
|
|
545
718
|
outer_table = np.arange(len(rotation_mapping), dtype=int)
|
546
719
|
lookup_table = np.array(
|
547
|
-
[
|
720
|
+
[master_rotation_mapping[key] for key in rotation_mapping.keys()],
|
548
721
|
dtype=int,
|
549
722
|
)
|
550
723
|
|
@@ -560,7 +733,7 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
560
733
|
scores_out.flush()
|
561
734
|
scores_out = None
|
562
735
|
|
563
|
-
|
736
|
+
states[i], scores = None, None
|
564
737
|
|
565
738
|
if use_memmap:
|
566
739
|
scores_out = np.memmap(
|
@@ -570,9 +743,9 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
570
743
|
return (
|
571
744
|
scores_out,
|
572
745
|
np.zeros(scores_out.ndim, dtype=int),
|
573
|
-
|
574
|
-
|
746
|
+
states[2],
|
747
|
+
cls._invert_rmap(master_rotation_mapping),
|
575
748
|
)
|
576
749
|
|
577
|
-
def
|
578
|
-
return
|
750
|
+
def result(self, state: Tuple, **kwargs) -> Tuple:
|
751
|
+
return state
|