pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3.0__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.0.data/scripts/estimate_memory_usage.py +76 -0
- pytme-0.3.0.data/scripts/match_template.py +1106 -0
- {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/postprocess.py +320 -190
- {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocess.py +21 -31
- {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocessor_gui.py +85 -19
- pytme-0.3.0.data/scripts/pytme_runner.py +771 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/METADATA +22 -20
- pytme-0.3.0.dist-info/RECORD +126 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/entry_points.txt +2 -1
- pytme-0.3.0.dist-info/licenses/LICENSE +339 -0
- scripts/estimate_memory_usage.py +76 -0
- scripts/eval.py +93 -0
- scripts/extract_candidates.py +224 -0
- scripts/match_template.py +349 -378
- pytme-0.2.9.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
- scripts/postprocess.py +320 -190
- scripts/preprocess.py +21 -31
- scripts/preprocessor_gui.py +85 -19
- scripts/pytme_runner.py +771 -0
- scripts/refine_matches.py +625 -0
- tests/preprocessing/test_frequency_filters.py +28 -14
- tests/test_analyzer.py +41 -36
- tests/test_backends.py +1 -0
- tests/test_matching_cli.py +109 -53
- tests/test_matching_data.py +5 -5
- tests/test_matching_exhaustive.py +1 -2
- tests/test_matching_optimization.py +4 -9
- tests/test_matching_utils.py +1 -1
- tests/test_orientations.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +26 -21
- tme/analyzer/aggregation.py +396 -222
- tme/analyzer/base.py +127 -0
- tme/analyzer/peaks.py +189 -201
- tme/analyzer/proxy.py +123 -0
- tme/backends/__init__.py +4 -3
- tme/backends/_cupy_utils.py +25 -24
- tme/backends/_jax_utils.py +20 -18
- tme/backends/cupy_backend.py +13 -26
- tme/backends/jax_backend.py +24 -23
- tme/backends/matching_backend.py +4 -3
- tme/backends/mlx_backend.py +4 -3
- tme/backends/npfftw_backend.py +34 -30
- tme/backends/pytorch_backend.py +18 -4
- tme/cli.py +126 -0
- tme/density.py +9 -7
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/__init__.py +3 -3
- tme/filters/_utils.py +36 -10
- tme/filters/bandpass.py +229 -188
- tme/filters/compose.py +5 -4
- tme/filters/ctf.py +516 -254
- tme/filters/reconstruction.py +91 -32
- tme/filters/wedge.py +196 -135
- tme/filters/whitening.py +37 -42
- tme/matching_data.py +28 -39
- tme/matching_exhaustive.py +31 -27
- tme/matching_optimization.py +5 -4
- tme/matching_scores.py +25 -15
- tme/matching_utils.py +158 -28
- tme/memory.py +4 -3
- tme/orientations.py +22 -9
- tme/parser.py +114 -33
- tme/preprocessor.py +6 -5
- tme/rotations.py +10 -7
- tme/structure.py +4 -3
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +0 -97
- pytme-0.2.9.dist-info/RECORD +0 -119
- pytme-0.2.9.dist-info/licenses/LICENSE +0 -153
- scripts/estimate_ram_usage.py +0 -97
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/WHEEL +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/top_level.txt +0 -0
tme/analyzer/aggregation.py
CHANGED
@@ -1,34 +1,34 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
Implements classes to analyze outputs from exhaustive template matching.
|
2
3
|
|
3
|
-
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
5
|
|
5
|
-
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
7
|
"""
|
7
8
|
|
8
|
-
from
|
9
|
-
from multiprocessing import Manager
|
10
|
-
from typing import Tuple, List, Dict, Generator
|
9
|
+
from typing import Tuple, List, Dict
|
11
10
|
|
12
11
|
import numpy as np
|
13
12
|
|
13
|
+
from .base import AbstractAnalyzer
|
14
14
|
from ..types import BackendArray
|
15
15
|
from ._utils import cart_to_score
|
16
16
|
from ..backends import backend as be
|
17
17
|
from ..matching_utils import (
|
18
18
|
create_mask,
|
19
19
|
array_to_memmap,
|
20
|
-
generate_tempfile_name,
|
21
20
|
apply_convolution_mode,
|
21
|
+
generate_tempfile_name,
|
22
22
|
)
|
23
23
|
|
24
|
-
|
25
24
|
__all__ = [
|
26
25
|
"MaxScoreOverRotations",
|
26
|
+
"MaxScoreOverRotationsConstrained",
|
27
27
|
"MaxScoreOverTranslations",
|
28
28
|
]
|
29
29
|
|
30
30
|
|
31
|
-
class MaxScoreOverRotations:
|
31
|
+
class MaxScoreOverRotations(AbstractAnalyzer):
|
32
32
|
"""
|
33
33
|
Determine the rotation maximizing the score over all possible translations.
|
34
34
|
|
@@ -36,10 +36,6 @@ class MaxScoreOverRotations:
|
|
36
36
|
----------
|
37
37
|
shape : tuple of int
|
38
38
|
Shape of array passed to :py:meth:`MaxScoreOverRotations.__call__`.
|
39
|
-
scores : BackendArray, optional
|
40
|
-
Array mapping translations to scores.
|
41
|
-
rotations : BackendArray, optional
|
42
|
-
Array mapping translations to rotation indices.
|
43
39
|
offset : BackendArray, optional
|
44
40
|
Coordinate origin considered during merging, zero by default.
|
45
41
|
score_threshold : float, optional
|
@@ -50,58 +46,47 @@ class MaxScoreOverRotations:
|
|
50
46
|
Memmap internal arrays, False by default.
|
51
47
|
thread_safe: bool, optional
|
52
48
|
Allow class to be modified by multiple processes, True by default.
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
scores : BackendArray
|
59
|
-
Mapping of translations to scores.
|
60
|
-
rotations : BackendArray
|
61
|
-
Mmapping of translations to rotation indices.
|
62
|
-
rotation_mapping : Dict
|
63
|
-
Mapping of rotations to rotation indices.
|
64
|
-
offset : BackendArray, optional
|
65
|
-
Coordinate origin considered during merging, zero by default
|
49
|
+
inversion_mapping : bool, optional
|
50
|
+
Do not use rotation matrix bytestrings for intermediate data handling.
|
51
|
+
This is useful for GPU backend where analyzers are not shared across
|
52
|
+
devices and every rotation is only observed once. It is generally
|
53
|
+
safe to deactivate inversion mapping, but at a cost of performance.
|
66
54
|
|
67
55
|
Examples
|
68
56
|
--------
|
69
57
|
The following achieves the minimal definition of a :py:class:`MaxScoreOverRotations`
|
70
58
|
instance
|
71
59
|
|
60
|
+
>>> import numpy as np
|
72
61
|
>>> from tme.analyzer import MaxScoreOverRotations
|
73
|
-
>>> analyzer = MaxScoreOverRotations(shape
|
62
|
+
>>> analyzer = MaxScoreOverRotations(shape=(50, 50))
|
74
63
|
|
75
64
|
The following simulates a template matching run by creating random data for a range
|
76
65
|
of rotations and sending it to ``analyzer`` via its __call__ method
|
77
66
|
|
67
|
+
>>> state = analyzer.init_state()
|
78
68
|
>>> for rotation_number in range(10):
|
79
69
|
>>> scores = np.random.rand(50,50)
|
80
70
|
>>> rotation = np.random.rand(scores.ndim, scores.ndim)
|
81
|
-
>>> analyzer(scores
|
71
|
+
>>> state = analyzer(state, scores=scores, rotation_matrix=rotation)
|
82
72
|
|
83
|
-
The aggregated scores can be extracted by invoking the
|
73
|
+
The aggregated scores can be extracted by invoking the result method of
|
84
74
|
``analyzer``
|
85
75
|
|
86
|
-
>>> results =
|
76
|
+
>>> results = analyzer.result(state)
|
87
77
|
|
88
78
|
The ``results`` tuple contains (1) the maximum scores for each translation,
|
89
79
|
(2) an offset which is relevant when merging results from split template matching
|
90
80
|
using :py:meth:`MaxScoreOverRotations.merge`, (3) the rotation used to obtain a
|
91
|
-
score for a given translation, (4) a dictionary mapping
|
92
|
-
|
81
|
+
score for a given translation, (4) a dictionary mapping indices used in (3) to
|
82
|
+
rotation matrices.
|
93
83
|
|
94
84
|
We can extract the ``optimal_score``, ``optimal_translation`` and ``optimal_rotation``
|
95
85
|
as follows
|
96
86
|
|
97
87
|
>>> optimal_score = results[0].max()
|
98
88
|
>>> optimal_translation = np.where(results[0] == results[0].max())
|
99
|
-
>>>
|
100
|
-
>>> for key, value in results[3].items():
|
101
|
-
>>> if value != optimal_rotation_index:
|
102
|
-
>>> continue
|
103
|
-
>>> optimal_rotation = np.frombuffer(key, rotation.dtype)
|
104
|
-
>>> optimal_rotation = optimal_rotation.reshape(scores.ndim, scores.ndim)
|
89
|
+
>>> optimal_rotation = results[2][optimal_translation]
|
105
90
|
|
106
91
|
The outlined procedure is a trivial method to identify high scoring peaks.
|
107
92
|
Alternatively, :py:class:`PeakCaller` offers a range of more elaborate approaches
|
@@ -111,156 +96,221 @@ class MaxScoreOverRotations:
|
|
111
96
|
def __init__(
|
112
97
|
self,
|
113
98
|
shape: Tuple[int],
|
114
|
-
scores: BackendArray = None,
|
115
|
-
rotations: BackendArray = None,
|
116
99
|
offset: BackendArray = None,
|
117
100
|
score_threshold: float = 0,
|
118
101
|
shm_handler: object = None,
|
119
102
|
use_memmap: bool = False,
|
120
|
-
|
121
|
-
|
103
|
+
inversion_mapping: bool = False,
|
104
|
+
jax_mode: bool = False,
|
122
105
|
**kwargs,
|
123
106
|
):
|
107
|
+
self._use_memmap = use_memmap
|
108
|
+
self._score_threshold = score_threshold
|
124
109
|
self._shape = tuple(int(x) for x in shape)
|
110
|
+
self._inversion_mapping = inversion_mapping
|
125
111
|
|
126
|
-
self.
|
127
|
-
if self.
|
128
|
-
self.
|
129
|
-
shape=self._shape, dtype=be._float_dtype, fill_value=score_threshold
|
130
|
-
)
|
131
|
-
self.rotations = rotations
|
132
|
-
if self.rotations is None:
|
133
|
-
self.rotations = be.full(self._shape, dtype=be._int_dtype, fill_value=-1)
|
134
|
-
|
135
|
-
self.scores = be.to_sharedarr(self.scores, shm_handler)
|
136
|
-
self.rotations = be.to_sharedarr(self.rotations, shm_handler)
|
112
|
+
self._jax_mode = jax_mode
|
113
|
+
if self._jax_mode:
|
114
|
+
self._inversion_mapping = False
|
137
115
|
|
138
116
|
if offset is None:
|
139
117
|
offset = be.zeros(len(self._shape), be._int_dtype)
|
140
|
-
self.
|
118
|
+
self._offset = be.astype(be.to_backend_array(offset), int)
|
141
119
|
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
120
|
+
@property
|
121
|
+
def shareable(self):
|
122
|
+
return True
|
123
|
+
|
124
|
+
def init_state(self):
|
125
|
+
"""
|
126
|
+
Initialize the analysis state.
|
127
|
+
|
128
|
+
Returns
|
129
|
+
-------
|
130
|
+
tuple
|
131
|
+
Initial state tuple containing (scores, rotations, rotation_mapping) where:
|
132
|
+
- scores : BackendArray of shape `self._shape` filled with `score_threshold`.
|
133
|
+
- rotations : BackendArray of shape `self._shape` filled with -1.
|
134
|
+
- rotation_mapping : dict, empty mapping from rotation bytes to indices.
|
135
|
+
"""
|
136
|
+
scores = be.full(
|
137
|
+
shape=self._shape, dtype=be._float_dtype, fill_value=self._score_threshold
|
138
|
+
)
|
139
|
+
rotations = be.full(self._shape, dtype=be._int_dtype, fill_value=-1)
|
140
|
+
return scores, rotations, {}
|
147
141
|
|
148
|
-
def
|
142
|
+
def __call__(
|
149
143
|
self,
|
150
|
-
|
151
|
-
|
152
|
-
|
144
|
+
state: Tuple,
|
145
|
+
scores: BackendArray,
|
146
|
+
rotation_matrix: BackendArray,
|
147
|
+
**kwargs,
|
148
|
+
) -> Tuple:
|
149
|
+
"""
|
150
|
+
Update the parameter store.
|
151
|
+
|
152
|
+
Parameters
|
153
|
+
----------
|
154
|
+
state : tuple
|
155
|
+
Current state tuple (scores, rotations, rotation_mapping) where:
|
156
|
+
- scores : BackendArray, current maximum scores.
|
157
|
+
- rotations : BackendArray, current rotation indices.
|
158
|
+
- rotation_mapping : dict, mapping from rotation bytes to indices.
|
159
|
+
scores : BackendArray
|
160
|
+
Array of new scores to update analyzer with.
|
161
|
+
rotation_matrix : BackendArray
|
162
|
+
Square matrix used to obtain the current rotation.
|
163
|
+
Returns
|
164
|
+
-------
|
165
|
+
tuple
|
166
|
+
Updated state tuple (scores, rotations, rotation_mapping).
|
167
|
+
"""
|
168
|
+
# be.tobytes behaviour caused overhead for certain GPU/CUDA combinations
|
169
|
+
# If the analyzer is not shared and each rotation is unique, we can
|
170
|
+
# use index to rotation mapping and invert prior to merging.
|
171
|
+
prev_scores, rotations, rotation_mapping = state
|
172
|
+
|
173
|
+
rotation_index = len(rotation_mapping)
|
174
|
+
rotation_matrix = be.astype(rotation_matrix, be._float_dtype)
|
175
|
+
if self._inversion_mapping:
|
176
|
+
rotation_mapping[rotation_index] = rotation_matrix
|
177
|
+
elif self._jax_mode:
|
178
|
+
rotation_index = kwargs.get("rotation_index", 0)
|
179
|
+
else:
|
180
|
+
rotation = be.tobytes(rotation_matrix)
|
181
|
+
rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
|
182
|
+
|
183
|
+
scores, rotations = be.max_score_over_rotations(
|
184
|
+
scores=scores,
|
185
|
+
max_scores=prev_scores,
|
186
|
+
rotations=rotations,
|
187
|
+
rotation_index=rotation_index,
|
188
|
+
)
|
189
|
+
return scores, rotations, rotation_mapping
|
190
|
+
|
191
|
+
@staticmethod
|
192
|
+
def _invert_rmap(rotation_mapping: dict) -> dict:
|
193
|
+
"""
|
194
|
+
Invert dictionary from rotation matrix bytestrings mapping to rotation
|
195
|
+
indices to rotation indices mapping to rotation matrices.
|
196
|
+
"""
|
197
|
+
new_map, ndim = {}, None
|
198
|
+
for k, v in rotation_mapping.items():
|
199
|
+
nbytes = be.datatype_bytes(be._float_dtype)
|
200
|
+
dtype = np.float32 if nbytes == 4 else np.float16
|
201
|
+
rmat = np.frombuffer(k, dtype=dtype)
|
202
|
+
if ndim is None:
|
203
|
+
ndim = int(np.sqrt(rmat.size))
|
204
|
+
new_map[v] = rmat.reshape(ndim, ndim)
|
205
|
+
return new_map
|
206
|
+
|
207
|
+
def result(
|
208
|
+
self,
|
209
|
+
state,
|
210
|
+
targetshape: Tuple[int] = None,
|
211
|
+
templateshape: Tuple[int] = None,
|
212
|
+
convolution_shape: Tuple[int] = None,
|
153
213
|
fourier_shift: Tuple[int] = None,
|
154
214
|
convolution_mode: str = None,
|
155
|
-
shm_handler=None,
|
156
215
|
**kwargs,
|
157
|
-
) ->
|
158
|
-
"""
|
159
|
-
|
160
|
-
|
216
|
+
) -> Tuple:
|
217
|
+
"""
|
218
|
+
Finalize the analysis result with optional postprocessing.
|
219
|
+
|
220
|
+
Parameters
|
221
|
+
----------
|
222
|
+
state : tuple
|
223
|
+
Current state tuple (scores, rotations, rotation_mapping) where:
|
224
|
+
- scores : BackendArray, current maximum scores.
|
225
|
+
- rotations : BackendArray, current rotation indices.
|
226
|
+
- rotation_mapping : dict, mapping from rotation indices to matrices.
|
227
|
+
targetshape : Tuple[int], optional
|
228
|
+
Shape of the target for convolution mode correction.
|
229
|
+
templateshape : Tuple[int], optional
|
230
|
+
Shape of the template for convolution mode correction.
|
231
|
+
convolution_shape : Tuple[int], optional
|
232
|
+
Shape used for convolution.
|
233
|
+
fourier_shift : Tuple[int], optional.
|
234
|
+
Shift to apply for Fourier correction.
|
235
|
+
convolution_mode : str, optional
|
236
|
+
Convolution mode for padding correction.
|
237
|
+
**kwargs
|
238
|
+
Additional keyword arguments.
|
239
|
+
|
240
|
+
Returns
|
241
|
+
-------
|
242
|
+
tuple
|
243
|
+
Final result tuple (scores, offset, rotations, rotation_mapping).
|
244
|
+
"""
|
245
|
+
scores, rotations, rotation_mapping = state
|
246
|
+
|
247
|
+
# Apply postprocessing if parameters are provided
|
161
248
|
if fourier_shift is not None:
|
162
249
|
axis = tuple(i for i in range(len(fourier_shift)))
|
163
250
|
scores = be.roll(scores, shift=fourier_shift, axis=axis)
|
164
251
|
rotations = be.roll(rotations, shift=fourier_shift, axis=axis)
|
165
252
|
|
166
|
-
convargs = {
|
167
|
-
"s1": targetshape,
|
168
|
-
"s2": templateshape,
|
169
|
-
"convolution_mode": convolution_mode,
|
170
|
-
"convolution_shape": convolution_shape,
|
171
|
-
}
|
172
253
|
if convolution_mode is not None:
|
254
|
+
convargs = {
|
255
|
+
"s1": targetshape,
|
256
|
+
"s2": templateshape,
|
257
|
+
"convolution_mode": convolution_mode,
|
258
|
+
"convolution_shape": convolution_shape,
|
259
|
+
}
|
173
260
|
scores = apply_convolution_mode(scores, **convargs)
|
174
261
|
rotations = apply_convolution_mode(rotations, **convargs)
|
175
262
|
|
176
|
-
self._shape, self.scores, self.rotations = scores.shape, scores, rotations
|
177
|
-
if shm_handler is not None:
|
178
|
-
self.scores = be.to_sharedarr(scores, shm_handler)
|
179
|
-
self.rotations = be.to_sharedarr(rotations, shm_handler)
|
180
|
-
return self
|
181
|
-
|
182
|
-
def __iter__(self) -> Generator:
|
183
|
-
scores = be.from_sharedarr(self.scores)
|
184
|
-
rotations = be.from_sharedarr(self.rotations)
|
185
|
-
|
186
263
|
scores = be.to_numpy_array(scores)
|
187
264
|
rotations = be.to_numpy_array(rotations)
|
188
265
|
if self._use_memmap:
|
189
266
|
scores = array_to_memmap(scores)
|
190
267
|
rotations = array_to_memmap(rotations)
|
191
|
-
else:
|
192
|
-
if type(self.scores) is not type(scores):
|
193
|
-
# Copy to avoid invalidation by shared memory handler
|
194
|
-
scores, rotations = scores.copy(), rotations.copy()
|
195
268
|
|
196
269
|
if self._inversion_mapping:
|
197
|
-
|
198
|
-
be.tobytes(v): k for k, v in self.rotation_mapping.items()
|
199
|
-
}
|
270
|
+
rotation_mapping = {be.tobytes(v): k for k, v in rotation_mapping.items()}
|
200
271
|
|
201
|
-
|
272
|
+
return (
|
202
273
|
scores,
|
203
|
-
be.to_numpy_array(self.
|
274
|
+
be.to_numpy_array(self._offset),
|
204
275
|
rotations,
|
205
|
-
|
276
|
+
self._invert_rmap(rotation_mapping),
|
206
277
|
)
|
207
|
-
yield from param_store
|
208
278
|
|
209
|
-
def
|
279
|
+
def _harmonize_states(states: List[Tuple]):
|
210
280
|
"""
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
----------
|
215
|
-
scores : BackendArray
|
216
|
-
Array of scores.
|
217
|
-
rotation_matrix : BackendArray
|
218
|
-
Square matrix describing the current rotation.
|
281
|
+
Create consistent reference frame for merging different analyzer
|
282
|
+
instances, w.r.t. rotations and output shape from different
|
283
|
+
splits of the target.
|
219
284
|
"""
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
rotation_index = len(self.rotation_mapping)
|
225
|
-
if self._inversion_mapping:
|
226
|
-
self.rotation_mapping[rotation_index] = rotation_matrix
|
227
|
-
else:
|
228
|
-
rotation = be.tobytes(rotation_matrix)
|
229
|
-
rotation_index = self.rotation_mapping.setdefault(
|
230
|
-
rotation, rotation_index
|
231
|
-
)
|
232
|
-
self.scores, self.rotations = be.max_score_over_rotations(
|
233
|
-
scores=scores,
|
234
|
-
max_scores=self.scores,
|
235
|
-
rotations=self.rotations,
|
236
|
-
rotation_index=rotation_index,
|
237
|
-
)
|
238
|
-
return None
|
285
|
+
new_rotation_mapping, out_shape = {}, None
|
286
|
+
for i in range(len(states)):
|
287
|
+
if states[i] is None:
|
288
|
+
continue
|
239
289
|
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
290
|
+
scores, offset, rotations, rotation_mapping = states[i]
|
291
|
+
if out_shape is None:
|
292
|
+
out_shape = np.zeros(scores.ndim, int)
|
293
|
+
out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
|
294
|
+
|
295
|
+
new_param = {}
|
296
|
+
for key, value in rotation_mapping.items():
|
297
|
+
rotation_bytes = be.tobytes(value)
|
298
|
+
new_param[rotation_bytes] = key
|
299
|
+
if rotation_bytes not in new_rotation_mapping:
|
300
|
+
new_rotation_mapping[rotation_bytes] = len(new_rotation_mapping)
|
301
|
+
states[i] = (scores, offset, rotations, new_param)
|
302
|
+
out_shape = tuple(int(x) for x in out_shape)
|
303
|
+
return new_rotation_mapping, out_shape, states
|
254
304
|
|
255
305
|
@classmethod
|
256
|
-
def merge(cls,
|
306
|
+
def merge(cls, results: List[Tuple], **kwargs) -> Tuple:
|
257
307
|
"""
|
258
308
|
Merge multiple instances of the current class.
|
259
309
|
|
260
310
|
Parameters
|
261
311
|
----------
|
262
|
-
|
263
|
-
List of instance's internal state created by applying `
|
312
|
+
results : list of tuple
|
313
|
+
List of instance's internal state created by applying `result`.
|
264
314
|
**kwargs : dict, optional
|
265
315
|
Optional keyword arguments.
|
266
316
|
|
@@ -276,8 +326,8 @@ class MaxScoreOverRotations:
|
|
276
326
|
Mapping between rotations and rotation indices.
|
277
327
|
"""
|
278
328
|
use_memmap = kwargs.get("use_memmap", False)
|
279
|
-
if len(
|
280
|
-
ret =
|
329
|
+
if len(results) == 1:
|
330
|
+
ret = results[0]
|
281
331
|
if use_memmap:
|
282
332
|
scores, offset, rotations, rotation_mapping = ret
|
283
333
|
scores = array_to_memmap(scores)
|
@@ -287,25 +337,12 @@ class MaxScoreOverRotations:
|
|
287
337
|
return ret
|
288
338
|
|
289
339
|
# Determine output array shape and create consistent rotation map
|
290
|
-
|
291
|
-
for i in range(len(param_stores)):
|
292
|
-
if param_stores[i] is None:
|
293
|
-
continue
|
294
|
-
|
295
|
-
scores, offset, rotations, rotation_mapping = param_stores[i]
|
296
|
-
if out_shape is None:
|
297
|
-
out_shape = np.zeros(scores.ndim, int)
|
298
|
-
scores_dtype, rotations_dtype = scores.dtype, rotations.dtype
|
299
|
-
out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
|
300
|
-
|
301
|
-
for key, value in rotation_mapping.items():
|
302
|
-
if key not in new_rotation_mapping:
|
303
|
-
new_rotation_mapping[key] = len(new_rotation_mapping)
|
304
|
-
|
340
|
+
master_rotation_mapping, out_shape, results = cls._harmonize_states(results)
|
305
341
|
if out_shape is None:
|
306
342
|
return None
|
307
343
|
|
308
|
-
|
344
|
+
scores_dtype = results[0][0].dtype
|
345
|
+
rotations_dtype = results[0][2].dtype
|
309
346
|
if use_memmap:
|
310
347
|
scores_out_filename = generate_tempfile_name()
|
311
348
|
rotations_out_filename = generate_tempfile_name()
|
@@ -331,8 +368,8 @@ class MaxScoreOverRotations:
|
|
331
368
|
)
|
332
369
|
rotations_out = np.full(out_shape, fill_value=-1, dtype=rotations_dtype)
|
333
370
|
|
334
|
-
for i in range(len(
|
335
|
-
if
|
371
|
+
for i in range(len(results)):
|
372
|
+
if results[i] is None:
|
336
373
|
continue
|
337
374
|
|
338
375
|
if use_memmap:
|
@@ -348,7 +385,7 @@ class MaxScoreOverRotations:
|
|
348
385
|
shape=out_shape,
|
349
386
|
dtype=rotations_dtype,
|
350
387
|
)
|
351
|
-
scores, offset, rotations, rotation_mapping =
|
388
|
+
scores, offset, rotations, rotation_mapping = results[i]
|
352
389
|
stops = np.add(offset, scores.shape).astype(int)
|
353
390
|
indices = tuple(slice(*pos) for pos in zip(offset, stops))
|
354
391
|
|
@@ -359,7 +396,7 @@ class MaxScoreOverRotations:
|
|
359
396
|
len(rotation_mapping) + 1, dtype=rotations_out.dtype
|
360
397
|
)
|
361
398
|
for key, value in rotation_mapping.items():
|
362
|
-
lookup_table[value] =
|
399
|
+
lookup_table[value] = master_rotation_mapping[key]
|
363
400
|
|
364
401
|
updated_rotations = rotations[indices_update]
|
365
402
|
if len(updated_rotations):
|
@@ -372,7 +409,7 @@ class MaxScoreOverRotations:
|
|
372
409
|
rotations_out.flush()
|
373
410
|
scores_out, rotations_out = None, None
|
374
411
|
|
375
|
-
|
412
|
+
results[i] = None
|
376
413
|
scores, rotations = None, None
|
377
414
|
|
378
415
|
if use_memmap:
|
@@ -390,13 +427,166 @@ class MaxScoreOverRotations:
|
|
390
427
|
scores_out,
|
391
428
|
np.zeros(scores_out.ndim, dtype=int),
|
392
429
|
rotations_out,
|
393
|
-
|
430
|
+
cls._invert_rmap(master_rotation_mapping),
|
394
431
|
)
|
395
432
|
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
433
|
+
|
434
|
+
class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
|
435
|
+
"""
|
436
|
+
Implements constrained template matching using rejection sampling.
|
437
|
+
|
438
|
+
Parameters
|
439
|
+
----------
|
440
|
+
cone_angle : float
|
441
|
+
Maximum accepted rotational deviation in degrees.
|
442
|
+
positions : BackendArray
|
443
|
+
Array of shape (n, d) with n seed point translations.
|
444
|
+
rotations : BackendArray
|
445
|
+
Array of shape (n, d, d) with n seed point rotation matrices.
|
446
|
+
reference : BackendArray
|
447
|
+
Reference orientation of the template, wlog defaults to (0,0,1).
|
448
|
+
acceptance_radius : int or tuple of ints
|
449
|
+
Translational acceptance radius around seed point in voxels.
|
450
|
+
**kwargs : dict, optional
|
451
|
+
Keyword arguments passed to the constructor of :py:class:`MaxScoreOverRotations`.
|
452
|
+
"""
|
453
|
+
|
454
|
+
def __init__(
|
455
|
+
self,
|
456
|
+
cone_angle: float,
|
457
|
+
positions: BackendArray,
|
458
|
+
rotations: BackendArray,
|
459
|
+
reference: BackendArray = (0, 0, 1),
|
460
|
+
acceptance_radius: int = 10,
|
461
|
+
**kwargs,
|
462
|
+
):
|
463
|
+
MaxScoreOverRotations.__init__(self, **kwargs)
|
464
|
+
|
465
|
+
if not isinstance(acceptance_radius, (int, Tuple)):
|
466
|
+
raise ValueError("acceptance_radius needs to be of type int or tuple.")
|
467
|
+
|
468
|
+
if isinstance(acceptance_radius, int):
|
469
|
+
acceptance_radius = (
|
470
|
+
acceptance_radius,
|
471
|
+
acceptance_radius,
|
472
|
+
acceptance_radius,
|
473
|
+
)
|
474
|
+
acceptance_radius = tuple(int(x) for x in acceptance_radius)
|
475
|
+
|
476
|
+
self._cone_angle = float(np.radians(cone_angle))
|
477
|
+
self._cone_cutoff = float(np.tan(self._cone_angle))
|
478
|
+
self._reference = be.astype(
|
479
|
+
be.reshape(be.to_backend_array(reference), (-1,)), be._float_dtype
|
480
|
+
)
|
481
|
+
positions = be.astype(be.to_backend_array(positions), be._int_dtype)
|
482
|
+
|
483
|
+
ndim = positions.shape[1]
|
484
|
+
rotate_mask = len(set(acceptance_radius)) != 1
|
485
|
+
extend = max(acceptance_radius)
|
486
|
+
mask = create_mask(
|
487
|
+
mask_type="ellipse",
|
488
|
+
radius=acceptance_radius,
|
489
|
+
shape=tuple(2 * extend + 1 for _ in range(ndim)),
|
490
|
+
center=tuple(extend for _ in range(ndim)),
|
491
|
+
)
|
492
|
+
self._score_mask = be.astype(be.to_backend_array(mask), be._float_dtype)
|
493
|
+
|
494
|
+
# Map position from real space to shifted score space
|
495
|
+
lower_limit = be.to_backend_array(self._offset)
|
496
|
+
positions = be.subtract(positions, lower_limit)
|
497
|
+
positions, valid_positions = cart_to_score(
|
498
|
+
positions=positions,
|
499
|
+
fast_shape=kwargs.get("fast_shape", None),
|
500
|
+
targetshape=kwargs.get("targetshape", None),
|
501
|
+
templateshape=kwargs.get("templateshape", None),
|
502
|
+
fourier_shift=kwargs.get("fourier_shift", None),
|
503
|
+
convolution_mode=kwargs.get("convolution_mode", None),
|
504
|
+
convolution_shape=kwargs.get("convolution_shape", None),
|
505
|
+
)
|
506
|
+
|
507
|
+
self._positions = positions[valid_positions]
|
508
|
+
rotations = be.to_backend_array(rotations)[valid_positions]
|
509
|
+
ex = be.astype(be.to_backend_array((1, 0, 0)), be._float_dtype)
|
510
|
+
ey = be.astype(be.to_backend_array((0, 1, 0)), be._float_dtype)
|
511
|
+
ez = be.astype(be.to_backend_array((0, 0, 1)), be._float_dtype)
|
512
|
+
|
513
|
+
self._normals_x = (rotations @ ex[..., None])[..., 0]
|
514
|
+
self._normals_y = (rotations @ ey[..., None])[..., 0]
|
515
|
+
self._normals_z = (rotations @ ez[..., None])[..., 0]
|
516
|
+
|
517
|
+
# Periodic wrapping could be avoided by padding the target
|
518
|
+
shape = be.to_backend_array(self._shape)
|
519
|
+
starts = be.subtract(self._positions, extend)
|
520
|
+
ret, (n, d), mshape = [], self._positions.shape, self._score_mask.shape
|
521
|
+
if starts.shape[0] > 0:
|
522
|
+
for i in range(d):
|
523
|
+
indices = starts[:, slice(i, i + 1)] + be.arange(mshape[i])[None]
|
524
|
+
indices = be.mod(indices, shape[i], out=indices)
|
525
|
+
indices_shape = (n, *tuple(1 if k != i else -1 for k in range(d)))
|
526
|
+
ret.append(be.reshape(indices, indices_shape))
|
527
|
+
|
528
|
+
self._index_grid = tuple(ret)
|
529
|
+
self._mask_shape = tuple(1 if i != 0 else -1 for i in range(1 + ndim))
|
530
|
+
|
531
|
+
if rotate_mask:
|
532
|
+
self._score_mask = be.zeros(
|
533
|
+
(rotations.shape[0], *self._score_mask.shape), dtype=be._float_dtype
|
534
|
+
)
|
535
|
+
for i in range(rotations.shape[0]):
|
536
|
+
mask = create_mask(
|
537
|
+
mask_type="ellipse",
|
538
|
+
radius=acceptance_radius,
|
539
|
+
shape=tuple(2 * extend + 1 for _ in range(ndim)),
|
540
|
+
center=tuple(extend for _ in range(ndim)),
|
541
|
+
orientation=be.to_numpy_array(rotations[i]),
|
542
|
+
)
|
543
|
+
self._score_mask[i] = be.astype(
|
544
|
+
be.to_backend_array(mask), be._float_dtype
|
545
|
+
)
|
546
|
+
|
547
|
+
def __call__(
|
548
|
+
self, state: Tuple, scores: BackendArray, rotation_matrix: BackendArray
|
549
|
+
) -> Tuple:
|
550
|
+
mask = self._get_constraint(rotation_matrix)
|
551
|
+
mask = self._get_score_mask(mask=mask, scores=scores)
|
552
|
+
|
553
|
+
scores = be.multiply(scores, mask, out=scores)
|
554
|
+
return super().__call__(state, scores=scores, rotation_matrix=rotation_matrix)
|
555
|
+
|
556
|
+
def _get_constraint(self, rotation_matrix: BackendArray) -> BackendArray:
|
557
|
+
"""
|
558
|
+
Determine whether the angle between projection of reference w.r.t.
|
559
|
+
a given rotation matrix and a set of rotations fall within the set
|
560
|
+
cone_angle cutoff.
|
561
|
+
|
562
|
+
Parameters
|
563
|
+
----------
|
564
|
+
rotation_matrix : BackendArray
|
565
|
+
Rotation matrix with shape (d,d).
|
566
|
+
|
567
|
+
Returns
|
568
|
+
-------
|
569
|
+
BackendArray
|
570
|
+
Boolean mask of shape (n, )
|
571
|
+
"""
|
572
|
+
template_rot = rotation_matrix @ self._reference
|
573
|
+
|
574
|
+
x = be.sum(be.multiply(self._normals_x, template_rot), axis=1)
|
575
|
+
y = be.sum(be.multiply(self._normals_y, template_rot), axis=1)
|
576
|
+
z = be.sum(be.multiply(self._normals_z, template_rot), axis=1)
|
577
|
+
|
578
|
+
return be.sqrt(x**2 + y**2) <= (z * self._cone_cutoff)
|
579
|
+
|
580
|
+
def _get_score_mask(self, mask: BackendArray, scores: BackendArray, **kwargs):
|
581
|
+
score_mask = be.zeros(scores.shape, scores.dtype)
|
582
|
+
|
583
|
+
if be.sum(mask) == 0:
|
584
|
+
return score_mask
|
585
|
+
mask = be.reshape(mask, self._mask_shape)
|
586
|
+
|
587
|
+
score_mask = be.addat(score_mask, self._index_grid, self._score_mask * mask)
|
588
|
+
return score_mask > 0
|
589
|
+
|
400
590
|
|
401
591
|
class MaxScoreOverTranslations(MaxScoreOverRotations):
|
402
592
|
"""
|
@@ -436,43 +626,40 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
436
626
|
super().__init__(
|
437
627
|
shape=shape_reduced, shm_handler=shm_handler, offset=offset, **kwargs
|
438
628
|
)
|
439
|
-
|
440
|
-
self.rotations = be.full(1, dtype=be._int_dtype, fill_value=-1)
|
441
|
-
self.rotations = be.to_sharedarr(self.rotations, shm_handler)
|
442
629
|
self._aggregate_axis = aggregate_axis
|
443
630
|
|
444
|
-
def
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
rotation = be.tobytes(rotation_matrix)
|
451
|
-
rotation_index = self.rotation_mapping.setdefault(
|
452
|
-
rotation, rotation_index
|
453
|
-
)
|
454
|
-
max_score = be.max(scores, axis=self._aggregate_axis)
|
455
|
-
self.scores[rotation_index] = max_score
|
456
|
-
return None
|
631
|
+
def init_state(self):
|
632
|
+
scores = be.full(
|
633
|
+
shape=self._shape, dtype=be._float_dtype, fill_value=self._score_threshold
|
634
|
+
)
|
635
|
+
rotations = be.full(1, dtype=be._int_dtype, fill_value=-1)
|
636
|
+
return scores, rotations, {}
|
457
637
|
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
638
|
+
def __call__(
|
639
|
+
self, state, scores: BackendArray, rotation_matrix: BackendArray
|
640
|
+
) -> Tuple:
|
641
|
+
prev_scores, rotations, rotation_mapping = state
|
642
|
+
|
643
|
+
rotation_index = len(rotation_mapping)
|
644
|
+
if self._inversion_mapping:
|
645
|
+
rotation_mapping[rotation_index] = rotation_matrix
|
646
|
+
else:
|
647
|
+
rotation = be.tobytes(rotation_matrix)
|
648
|
+
rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
|
649
|
+
max_score = be.max(scores, axis=self._aggregate_axis)
|
650
|
+
|
651
|
+
update = prev_scores[rotation_index]
|
652
|
+
update = be.maximum(max_score, update, out=update)
|
653
|
+
return prev_scores, rotations, rotation_mapping
|
467
654
|
|
468
655
|
@classmethod
|
469
|
-
def merge(cls,
|
656
|
+
def merge(cls, states: List[Tuple], **kwargs) -> Tuple:
|
470
657
|
"""
|
471
658
|
Merge multiple instances of the current class.
|
472
659
|
|
473
660
|
Parameters
|
474
661
|
----------
|
475
|
-
|
662
|
+
states : list of tuple
|
476
663
|
List of instance's internal state created by applying `tuple(instance)`.
|
477
664
|
**kwargs : dict, optional
|
478
665
|
Optional keyword arguments.
|
@@ -488,31 +675,18 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
488
675
|
Dict
|
489
676
|
Mapping between rotations and rotation indices.
|
490
677
|
"""
|
491
|
-
if len(
|
492
|
-
return
|
678
|
+
if len(states) == 1:
|
679
|
+
return states[0]
|
493
680
|
|
494
681
|
# Determine output array shape and create consistent rotation map
|
495
|
-
|
496
|
-
for i in range(len(param_stores)):
|
497
|
-
if param_stores[i] is None:
|
498
|
-
continue
|
499
|
-
|
500
|
-
scores, offset, rotations, rotation_mapping = param_stores[i]
|
501
|
-
if out_shape is None:
|
502
|
-
out_shape = np.zeros(scores.ndim, int)
|
503
|
-
scores_dtype, rotations_out = scores.dtype, rotations
|
504
|
-
out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
|
505
|
-
|
506
|
-
for key, value in rotation_mapping.items():
|
507
|
-
if key not in new_rotation_mapping:
|
508
|
-
new_rotation_mapping[key] = len(new_rotation_mapping)
|
509
|
-
|
682
|
+
states, master_rotation_mapping, out_shape = cls._harmonize_states(states)
|
510
683
|
if out_shape is None:
|
511
684
|
return None
|
512
685
|
|
513
|
-
out_shape[0] = len(
|
686
|
+
out_shape[0] = len(master_rotation_mapping)
|
514
687
|
out_shape = tuple(int(x) for x in out_shape)
|
515
688
|
|
689
|
+
scores_dtype = states[0][0].dtype
|
516
690
|
use_memmap = kwargs.get("use_memmap", False)
|
517
691
|
if use_memmap:
|
518
692
|
scores_out_filename = generate_tempfile_name()
|
@@ -528,8 +702,8 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
528
702
|
dtype=scores_dtype,
|
529
703
|
)
|
530
704
|
|
531
|
-
for i in range(len(
|
532
|
-
if
|
705
|
+
for i in range(len(states)):
|
706
|
+
if states[i] is None:
|
533
707
|
continue
|
534
708
|
|
535
709
|
if use_memmap:
|
@@ -539,11 +713,11 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
539
713
|
shape=out_shape,
|
540
714
|
dtype=scores_dtype,
|
541
715
|
)
|
542
|
-
scores, offset, rotations, rotation_mapping =
|
716
|
+
scores, offset, rotations, rotation_mapping = states[i]
|
543
717
|
|
544
718
|
outer_table = np.arange(len(rotation_mapping), dtype=int)
|
545
719
|
lookup_table = np.array(
|
546
|
-
[
|
720
|
+
[master_rotation_mapping[key] for key in rotation_mapping.keys()],
|
547
721
|
dtype=int,
|
548
722
|
)
|
549
723
|
|
@@ -559,7 +733,7 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
559
733
|
scores_out.flush()
|
560
734
|
scores_out = None
|
561
735
|
|
562
|
-
|
736
|
+
states[i], scores = None, None
|
563
737
|
|
564
738
|
if use_memmap:
|
565
739
|
scores_out = np.memmap(
|
@@ -569,9 +743,9 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
569
743
|
return (
|
570
744
|
scores_out,
|
571
745
|
np.zeros(scores_out.ndim, dtype=int),
|
572
|
-
|
573
|
-
|
746
|
+
states[2],
|
747
|
+
cls._invert_rmap(master_rotation_mapping),
|
574
748
|
)
|
575
749
|
|
576
|
-
def
|
577
|
-
return
|
750
|
+
def result(self, state: Tuple, **kwargs) -> Tuple:
|
751
|
+
return state
|