pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py → pytme-0.3b0.data/scripts/estimate_memory_usage.py +16 -33
- {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/match_template.py +224 -223
- {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/postprocess.py +283 -163
- {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/preprocess.py +11 -8
- {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/preprocessor_gui.py +10 -9
- {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/METADATA +11 -9
- {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/RECORD +61 -58
- {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/entry_points.txt +1 -1
- scripts/{estimate_ram_usage.py → estimate_memory_usage.py} +16 -33
- scripts/extract_candidates.py +224 -0
- scripts/match_template.py +224 -223
- scripts/postprocess.py +283 -163
- scripts/preprocess.py +11 -8
- scripts/preprocessor_gui.py +10 -9
- scripts/refine_matches.py +626 -0
- tests/preprocessing/test_frequency_filters.py +9 -4
- tests/test_analyzer.py +143 -138
- tests/test_matching_cli.py +85 -29
- tests/test_matching_exhaustive.py +1 -2
- tests/test_matching_optimization.py +4 -9
- tests/test_orientations.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +25 -17
- tme/analyzer/aggregation.py +385 -220
- tme/analyzer/base.py +138 -0
- tme/analyzer/peaks.py +150 -88
- tme/analyzer/proxy.py +122 -0
- tme/backends/__init__.py +4 -3
- tme/backends/_cupy_utils.py +25 -24
- tme/backends/_jax_utils.py +4 -3
- tme/backends/cupy_backend.py +4 -13
- tme/backends/jax_backend.py +6 -8
- tme/backends/matching_backend.py +4 -3
- tme/backends/mlx_backend.py +4 -3
- tme/backends/npfftw_backend.py +7 -5
- tme/backends/pytorch_backend.py +14 -4
- tme/cli.py +126 -0
- tme/density.py +4 -3
- tme/filters/__init__.py +1 -1
- tme/filters/_utils.py +4 -3
- tme/filters/bandpass.py +6 -4
- tme/filters/compose.py +5 -4
- tme/filters/ctf.py +426 -214
- tme/filters/reconstruction.py +58 -28
- tme/filters/wedge.py +139 -61
- tme/filters/whitening.py +36 -36
- tme/matching_data.py +4 -3
- tme/matching_exhaustive.py +17 -16
- tme/matching_optimization.py +5 -4
- tme/matching_scores.py +4 -3
- tme/matching_utils.py +6 -4
- tme/memory.py +4 -3
- tme/orientations.py +9 -6
- tme/parser.py +5 -4
- tme/preprocessor.py +4 -3
- tme/rotations.py +10 -7
- tme/structure.py +4 -3
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/WHEEL +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/top_level.txt +0 -0
tme/analyzer/aggregation.py
CHANGED
@@ -1,34 +1,34 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
Implements classes to analyze outputs from exhaustive template matching.
|
2
3
|
|
3
|
-
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
5
|
|
5
|
-
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
7
|
"""
|
7
8
|
|
8
|
-
from
|
9
|
-
from multiprocessing import Manager
|
10
|
-
from typing import Tuple, List, Dict, Generator
|
9
|
+
from typing import Tuple, List, Dict
|
11
10
|
|
12
11
|
import numpy as np
|
13
12
|
|
13
|
+
from .base import AbstractAnalyzer
|
14
14
|
from ..types import BackendArray
|
15
15
|
from ._utils import cart_to_score
|
16
16
|
from ..backends import backend as be
|
17
17
|
from ..matching_utils import (
|
18
18
|
create_mask,
|
19
19
|
array_to_memmap,
|
20
|
-
generate_tempfile_name,
|
21
20
|
apply_convolution_mode,
|
21
|
+
generate_tempfile_name,
|
22
22
|
)
|
23
23
|
|
24
|
-
|
25
24
|
__all__ = [
|
26
25
|
"MaxScoreOverRotations",
|
26
|
+
"MaxScoreOverRotationsConstrained",
|
27
27
|
"MaxScoreOverTranslations",
|
28
28
|
]
|
29
29
|
|
30
30
|
|
31
|
-
class MaxScoreOverRotations:
|
31
|
+
class MaxScoreOverRotations(AbstractAnalyzer):
|
32
32
|
"""
|
33
33
|
Determine the rotation maximizing the score over all possible translations.
|
34
34
|
|
@@ -36,10 +36,6 @@ class MaxScoreOverRotations:
|
|
36
36
|
----------
|
37
37
|
shape : tuple of int
|
38
38
|
Shape of array passed to :py:meth:`MaxScoreOverRotations.__call__`.
|
39
|
-
scores : BackendArray, optional
|
40
|
-
Array mapping translations to scores.
|
41
|
-
rotations : BackendArray, optional
|
42
|
-
Array mapping translations to rotation indices.
|
43
39
|
offset : BackendArray, optional
|
44
40
|
Coordinate origin considered during merging, zero by default.
|
45
41
|
score_threshold : float, optional
|
@@ -50,19 +46,11 @@ class MaxScoreOverRotations:
|
|
50
46
|
Memmap internal arrays, False by default.
|
51
47
|
thread_safe: bool, optional
|
52
48
|
Allow class to be modified by multiple processes, True by default.
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
scores : BackendArray
|
59
|
-
Mapping of translations to scores.
|
60
|
-
rotations : BackendArray
|
61
|
-
Mmapping of translations to rotation indices.
|
62
|
-
rotation_mapping : Dict
|
63
|
-
Mapping of rotations to rotation indices.
|
64
|
-
offset : BackendArray, optional
|
65
|
-
Coordinate origin considered during merging, zero by default
|
49
|
+
inversion_mapping : bool, optional
|
50
|
+
Do not use rotation matrix bytestrings for intermediate data handling.
|
51
|
+
This is useful for GPU backend where analyzers are not shared across
|
52
|
+
devices and every rotation is only observed once. It is generally
|
53
|
+
safe to deactivate inversion mapping, but at a cost of performance.
|
66
54
|
|
67
55
|
Examples
|
68
56
|
--------
|
@@ -75,33 +63,29 @@ class MaxScoreOverRotations:
|
|
75
63
|
The following simulates a template matching run by creating random data for a range
|
76
64
|
of rotations and sending it to ``analyzer`` via its __call__ method
|
77
65
|
|
66
|
+
>>> state = analyzer.init_state()
|
78
67
|
>>> for rotation_number in range(10):
|
79
68
|
>>> scores = np.random.rand(50,50)
|
80
69
|
>>> rotation = np.random.rand(scores.ndim, scores.ndim)
|
81
|
-
>>> analyzer(scores = scores, rotation_matrix = rotation)
|
70
|
+
>>> state = analyzer(state, scores = scores, rotation_matrix = rotation)
|
82
71
|
|
83
|
-
The aggregated scores can be extracted by invoking the
|
72
|
+
The aggregated scores can be extracted by invoking the result method of
|
84
73
|
``analyzer``
|
85
74
|
|
86
|
-
>>> results =
|
75
|
+
>>> results = analyzer.result(state)
|
87
76
|
|
88
77
|
The ``results`` tuple contains (1) the maximum scores for each translation,
|
89
78
|
(2) an offset which is relevant when merging results from split template matching
|
90
79
|
using :py:meth:`MaxScoreOverRotations.merge`, (3) the rotation used to obtain a
|
91
|
-
score for a given translation, (4) a dictionary mapping
|
92
|
-
|
80
|
+
score for a given translation, (4) a dictionary mapping the rotation indices
|
81
|
+
used in (3) to the corresponding rotation matrices.
|
93
82
|
|
94
83
|
We can extract the ``optimal_score``, ``optimal_translation`` and ``optimal_rotation``
|
95
84
|
as follows
|
96
85
|
|
97
86
|
>>> optimal_score = results[0].max()
|
98
87
|
>>> optimal_translation = np.where(results[0] == results[0].max())
|
99
|
-
>>>
|
100
|
-
>>> for key, value in results[3].items():
|
101
|
-
>>> if value != optimal_rotation_index:
|
102
|
-
>>> continue
|
103
|
-
>>> optimal_rotation = np.frombuffer(key, rotation.dtype)
|
104
|
-
>>> optimal_rotation = optimal_rotation.reshape(scores.ndim, scores.ndim)
|
88
|
+
>>> optimal_rotation = results[2][optimal_translation]
|
105
89
|
|
106
90
|
The outlined procedure is a trivial method to identify high scoring peaks.
|
107
91
|
Alternatively, :py:class:`PeakCaller` offers a range of more elaborate approaches
|
@@ -111,156 +95,213 @@ class MaxScoreOverRotations:
|
|
111
95
|
def __init__(
|
112
96
|
self,
|
113
97
|
shape: Tuple[int],
|
114
|
-
scores: BackendArray = None,
|
115
|
-
rotations: BackendArray = None,
|
116
98
|
offset: BackendArray = None,
|
117
99
|
score_threshold: float = 0,
|
118
100
|
shm_handler: object = None,
|
119
101
|
use_memmap: bool = False,
|
120
|
-
|
121
|
-
only_unique_rotations: bool = False,
|
102
|
+
inversion_mapping: bool = False,
|
122
103
|
**kwargs,
|
123
104
|
):
|
105
|
+
self._use_memmap = use_memmap
|
106
|
+
self._score_threshold = score_threshold
|
124
107
|
self._shape = tuple(int(x) for x in shape)
|
125
|
-
|
126
|
-
self.scores = scores
|
127
|
-
if self.scores is None:
|
128
|
-
self.scores = be.full(
|
129
|
-
shape=self._shape, dtype=be._float_dtype, fill_value=score_threshold
|
130
|
-
)
|
131
|
-
self.rotations = rotations
|
132
|
-
if self.rotations is None:
|
133
|
-
self.rotations = be.full(self._shape, dtype=be._int_dtype, fill_value=-1)
|
134
|
-
|
135
|
-
self.scores = be.to_sharedarr(self.scores, shm_handler)
|
136
|
-
self.rotations = be.to_sharedarr(self.rotations, shm_handler)
|
108
|
+
self._inversion_mapping = inversion_mapping
|
137
109
|
|
138
110
|
if offset is None:
|
139
111
|
offset = be.zeros(len(self._shape), be._int_dtype)
|
140
|
-
self.
|
112
|
+
self._offset = be.astype(be.to_backend_array(offset), int)
|
141
113
|
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
114
|
+
@property
|
115
|
+
def shareable(self):
|
116
|
+
return True
|
117
|
+
|
118
|
+
def init_state(self):
|
119
|
+
"""
|
120
|
+
Initialize the analysis state.
|
147
121
|
|
148
|
-
|
122
|
+
Returns
|
123
|
+
-------
|
124
|
+
tuple
|
125
|
+
Initial state tuple containing (scores, rotations, rotation_mapping) where:
|
126
|
+
- scores : BackendArray of shape `self._shape` filled with `score_threshold`.
|
127
|
+
- rotations : BackendArray of shape `self._shape` filled with -1.
|
128
|
+
- rotation_mapping : dict, empty mapping from rotation bytes to indices.
|
129
|
+
"""
|
130
|
+
scores = be.full(
|
131
|
+
shape=self._shape, dtype=be._float_dtype, fill_value=self._score_threshold
|
132
|
+
)
|
133
|
+
rotations = be.full(self._shape, dtype=be._int_dtype, fill_value=-1)
|
134
|
+
return scores, rotations, {}
|
135
|
+
|
136
|
+
def __call__(
|
149
137
|
self,
|
150
|
-
|
151
|
-
|
152
|
-
|
138
|
+
state: Tuple,
|
139
|
+
scores: BackendArray,
|
140
|
+
rotation_matrix: BackendArray,
|
141
|
+
) -> Tuple:
|
142
|
+
"""
|
143
|
+
Update the parameter store.
|
144
|
+
|
145
|
+
Parameters
|
146
|
+
----------
|
147
|
+
state : tuple
|
148
|
+
Current state tuple (scores, rotations, rotation_mapping) where:
|
149
|
+
- scores : BackendArray, current maximum scores.
|
150
|
+
- rotations : BackendArray, current rotation indices.
|
151
|
+
- rotation_mapping : dict, mapping from rotation bytes to indices.
|
152
|
+
scores : BackendArray
|
153
|
+
Array of new scores to update analyzer with.
|
154
|
+
rotation_matrix : BackendArray
|
155
|
+
Square matrix used to obtain the current rotation.
|
156
|
+
Returns
|
157
|
+
-------
|
158
|
+
tuple
|
159
|
+
Updated state tuple (scores, rotations, rotation_mapping).
|
160
|
+
"""
|
161
|
+
# be.tobytes behaviour caused overhead for certain GPU/CUDA combinations
|
162
|
+
# If the analyzer is not shared and each rotation is unique, we can
|
163
|
+
# use index to rotation mapping and invert prior to merging.
|
164
|
+
prev_scores, rotations, rotation_mapping = state
|
165
|
+
|
166
|
+
rotation_index = len(rotation_mapping)
|
167
|
+
rotation_matrix = be.astype(rotation_matrix, be._float_dtype)
|
168
|
+
if self._inversion_mapping:
|
169
|
+
rotation_mapping[rotation_index] = rotation_matrix
|
170
|
+
else:
|
171
|
+
rotation = be.tobytes(rotation_matrix)
|
172
|
+
rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
|
173
|
+
|
174
|
+
scores, rotations = be.max_score_over_rotations(
|
175
|
+
scores=scores,
|
176
|
+
max_scores=prev_scores,
|
177
|
+
rotations=rotations,
|
178
|
+
rotation_index=rotation_index,
|
179
|
+
)
|
180
|
+
return scores, rotations, rotation_mapping
|
181
|
+
|
182
|
+
@staticmethod
|
183
|
+
def _invert_rmap(rotation_mapping: dict) -> dict:
|
184
|
+
"""
|
185
|
+
Invert dictionary from rotation matrix bytestrings mapping to rotation
|
186
|
+
indices to rotation indices mapping to rotation matrices.
|
187
|
+
"""
|
188
|
+
new_map, ndim = {}, None
|
189
|
+
for k, v in rotation_mapping.items():
|
190
|
+
nbytes = be.datatype_bytes(be._float_dtype)
|
191
|
+
dtype = np.float32 if nbytes == 4 else np.float16
|
192
|
+
rmat = np.frombuffer(k, dtype=dtype)
|
193
|
+
if ndim is None:
|
194
|
+
ndim = int(np.sqrt(rmat.size))
|
195
|
+
new_map[v] = rmat.reshape(ndim, ndim)
|
196
|
+
return new_map
|
197
|
+
|
198
|
+
def result(
|
199
|
+
self,
|
200
|
+
state,
|
201
|
+
targetshape: Tuple[int] = None,
|
202
|
+
templateshape: Tuple[int] = None,
|
203
|
+
convolution_shape: Tuple[int] = None,
|
153
204
|
fourier_shift: Tuple[int] = None,
|
154
205
|
convolution_mode: str = None,
|
155
|
-
shm_handler=None,
|
156
206
|
**kwargs,
|
157
|
-
) ->
|
158
|
-
"""
|
159
|
-
|
160
|
-
|
207
|
+
) -> Tuple:
|
208
|
+
"""
|
209
|
+
Finalize the analysis result with optional postprocessing.
|
210
|
+
|
211
|
+
Parameters
|
212
|
+
----------
|
213
|
+
state : tuple
|
214
|
+
Current state tuple (scores, rotations, rotation_mapping) where:
|
215
|
+
- scores : BackendArray, current maximum scores.
|
216
|
+
- rotations : BackendArray, current rotation indices.
|
217
|
+
- rotation_mapping : dict, mapping from rotation indices to matrices.
|
218
|
+
targetshape : Tuple[int], optional
|
219
|
+
Shape of the target for convolution mode correction.
|
220
|
+
templateshape : Tuple[int], optional
|
221
|
+
Shape of the template for convolution mode correction.
|
222
|
+
convolution_shape : Tuple[int], optional
|
223
|
+
Shape used for convolution.
|
224
|
+
fourier_shift : Tuple[int], optional.
|
225
|
+
Shift to apply for Fourier correction.
|
226
|
+
convolution_mode : str, optional
|
227
|
+
Convolution mode for padding correction.
|
228
|
+
**kwargs
|
229
|
+
Additional keyword arguments.
|
230
|
+
|
231
|
+
Returns
|
232
|
+
-------
|
233
|
+
tuple
|
234
|
+
Final result tuple (scores, offset, rotations, rotation_mapping).
|
235
|
+
"""
|
236
|
+
scores, rotations, rotation_mapping = state
|
237
|
+
|
238
|
+
# Apply postprocessing if parameters are provided
|
161
239
|
if fourier_shift is not None:
|
162
240
|
axis = tuple(i for i in range(len(fourier_shift)))
|
163
241
|
scores = be.roll(scores, shift=fourier_shift, axis=axis)
|
164
242
|
rotations = be.roll(rotations, shift=fourier_shift, axis=axis)
|
165
243
|
|
166
|
-
convargs = {
|
167
|
-
"s1": targetshape,
|
168
|
-
"s2": templateshape,
|
169
|
-
"convolution_mode": convolution_mode,
|
170
|
-
"convolution_shape": convolution_shape,
|
171
|
-
}
|
172
244
|
if convolution_mode is not None:
|
245
|
+
convargs = {
|
246
|
+
"s1": targetshape,
|
247
|
+
"s2": templateshape,
|
248
|
+
"convolution_mode": convolution_mode,
|
249
|
+
"convolution_shape": convolution_shape,
|
250
|
+
}
|
173
251
|
scores = apply_convolution_mode(scores, **convargs)
|
174
252
|
rotations = apply_convolution_mode(rotations, **convargs)
|
175
253
|
|
176
|
-
self._shape, self.scores, self.rotations = scores.shape, scores, rotations
|
177
|
-
if shm_handler is not None:
|
178
|
-
self.scores = be.to_sharedarr(scores, shm_handler)
|
179
|
-
self.rotations = be.to_sharedarr(rotations, shm_handler)
|
180
|
-
return self
|
181
|
-
|
182
|
-
def __iter__(self) -> Generator:
|
183
|
-
scores = be.from_sharedarr(self.scores)
|
184
|
-
rotations = be.from_sharedarr(self.rotations)
|
185
|
-
|
186
254
|
scores = be.to_numpy_array(scores)
|
187
255
|
rotations = be.to_numpy_array(rotations)
|
188
256
|
if self._use_memmap:
|
189
257
|
scores = array_to_memmap(scores)
|
190
258
|
rotations = array_to_memmap(rotations)
|
191
|
-
else:
|
192
|
-
if type(self.scores) is not type(scores):
|
193
|
-
# Copy to avoid invalidation by shared memory handler
|
194
|
-
scores, rotations = scores.copy(), rotations.copy()
|
195
259
|
|
196
260
|
if self._inversion_mapping:
|
197
|
-
|
198
|
-
be.tobytes(v): k for k, v in self.rotation_mapping.items()
|
199
|
-
}
|
261
|
+
rotation_mapping = {be.tobytes(v): k for k, v in rotation_mapping.items()}
|
200
262
|
|
201
|
-
|
263
|
+
return (
|
202
264
|
scores,
|
203
|
-
be.to_numpy_array(self.
|
265
|
+
be.to_numpy_array(self._offset),
|
204
266
|
rotations,
|
205
|
-
|
267
|
+
self._invert_rmap(rotation_mapping),
|
206
268
|
)
|
207
|
-
yield from param_store
|
208
269
|
|
209
|
-
def
|
270
|
+
def _harmonize_states(states: List[Tuple]):
|
210
271
|
"""
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
----------
|
215
|
-
scores : BackendArray
|
216
|
-
Array of scores.
|
217
|
-
rotation_matrix : BackendArray
|
218
|
-
Square matrix describing the current rotation.
|
272
|
+
Create consistent reference frame for merging different analyzer
|
273
|
+
instances, w.r.t. rotations and output shape from different
|
274
|
+
splits of the target.
|
219
275
|
"""
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
rotation_index = len(self.rotation_mapping)
|
225
|
-
if self._inversion_mapping:
|
226
|
-
self.rotation_mapping[rotation_index] = rotation_matrix
|
227
|
-
else:
|
228
|
-
rotation = be.tobytes(rotation_matrix)
|
229
|
-
rotation_index = self.rotation_mapping.setdefault(
|
230
|
-
rotation, rotation_index
|
231
|
-
)
|
232
|
-
self.scores, self.rotations = be.max_score_over_rotations(
|
233
|
-
scores=scores,
|
234
|
-
max_scores=self.scores,
|
235
|
-
rotations=self.rotations,
|
236
|
-
rotation_index=rotation_index,
|
237
|
-
)
|
238
|
-
return None
|
276
|
+
new_rotation_mapping, out_shape = {}, None
|
277
|
+
for i in range(len(states)):
|
278
|
+
if states[i] is None:
|
279
|
+
continue
|
239
280
|
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
281
|
+
scores, offset, rotations, rotation_mapping = states[i]
|
282
|
+
if out_shape is None:
|
283
|
+
out_shape = np.zeros(scores.ndim, int)
|
284
|
+
out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
|
285
|
+
|
286
|
+
new_param = {}
|
287
|
+
for key, value in rotation_mapping.items():
|
288
|
+
rotation_bytes = be.tobytes(value)
|
289
|
+
new_param[rotation_bytes] = key
|
290
|
+
if rotation_bytes not in new_rotation_mapping:
|
291
|
+
new_rotation_mapping[rotation_bytes] = len(new_rotation_mapping)
|
292
|
+
states[i] = (scores, offset, rotations, new_param)
|
293
|
+
out_shape = tuple(int(x) for x in out_shape)
|
294
|
+
return new_rotation_mapping, out_shape, states
|
254
295
|
|
255
296
|
@classmethod
|
256
|
-
def merge(cls,
|
297
|
+
def merge(cls, results: List[Tuple], **kwargs) -> Tuple:
|
257
298
|
"""
|
258
299
|
Merge multiple instances of the current class.
|
259
300
|
|
260
301
|
Parameters
|
261
302
|
----------
|
262
|
-
|
263
|
-
List of instance's internal state created by applying `
|
303
|
+
results : list of tuple
|
304
|
+
List of instance's internal state created by applying `result`.
|
264
305
|
**kwargs : dict, optional
|
265
306
|
Optional keyword arguments.
|
266
307
|
|
@@ -276,8 +317,8 @@ class MaxScoreOverRotations:
|
|
276
317
|
Mapping between rotations and rotation indices.
|
277
318
|
"""
|
278
319
|
use_memmap = kwargs.get("use_memmap", False)
|
279
|
-
if len(
|
280
|
-
ret =
|
320
|
+
if len(results) == 1:
|
321
|
+
ret = results[0]
|
281
322
|
if use_memmap:
|
282
323
|
scores, offset, rotations, rotation_mapping = ret
|
283
324
|
scores = array_to_memmap(scores)
|
@@ -287,25 +328,12 @@ class MaxScoreOverRotations:
|
|
287
328
|
return ret
|
288
329
|
|
289
330
|
# Determine output array shape and create consistent rotation map
|
290
|
-
|
291
|
-
for i in range(len(param_stores)):
|
292
|
-
if param_stores[i] is None:
|
293
|
-
continue
|
294
|
-
|
295
|
-
scores, offset, rotations, rotation_mapping = param_stores[i]
|
296
|
-
if out_shape is None:
|
297
|
-
out_shape = np.zeros(scores.ndim, int)
|
298
|
-
scores_dtype, rotations_dtype = scores.dtype, rotations.dtype
|
299
|
-
out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
|
300
|
-
|
301
|
-
for key, value in rotation_mapping.items():
|
302
|
-
if key not in new_rotation_mapping:
|
303
|
-
new_rotation_mapping[key] = len(new_rotation_mapping)
|
304
|
-
|
331
|
+
master_rotation_mapping, out_shape, results = cls._harmonize_states(results)
|
305
332
|
if out_shape is None:
|
306
333
|
return None
|
307
334
|
|
308
|
-
|
335
|
+
scores_dtype = results[0][0].dtype
|
336
|
+
rotations_dtype = results[0][2].dtype
|
309
337
|
if use_memmap:
|
310
338
|
scores_out_filename = generate_tempfile_name()
|
311
339
|
rotations_out_filename = generate_tempfile_name()
|
@@ -331,8 +359,8 @@ class MaxScoreOverRotations:
|
|
331
359
|
)
|
332
360
|
rotations_out = np.full(out_shape, fill_value=-1, dtype=rotations_dtype)
|
333
361
|
|
334
|
-
for i in range(len(
|
335
|
-
if
|
362
|
+
for i in range(len(results)):
|
363
|
+
if results[i] is None:
|
336
364
|
continue
|
337
365
|
|
338
366
|
if use_memmap:
|
@@ -348,7 +376,7 @@ class MaxScoreOverRotations:
|
|
348
376
|
shape=out_shape,
|
349
377
|
dtype=rotations_dtype,
|
350
378
|
)
|
351
|
-
scores, offset, rotations, rotation_mapping =
|
379
|
+
scores, offset, rotations, rotation_mapping = results[i]
|
352
380
|
stops = np.add(offset, scores.shape).astype(int)
|
353
381
|
indices = tuple(slice(*pos) for pos in zip(offset, stops))
|
354
382
|
|
@@ -359,7 +387,7 @@ class MaxScoreOverRotations:
|
|
359
387
|
len(rotation_mapping) + 1, dtype=rotations_out.dtype
|
360
388
|
)
|
361
389
|
for key, value in rotation_mapping.items():
|
362
|
-
lookup_table[value] =
|
390
|
+
lookup_table[value] = master_rotation_mapping[key]
|
363
391
|
|
364
392
|
updated_rotations = rotations[indices_update]
|
365
393
|
if len(updated_rotations):
|
@@ -372,7 +400,7 @@ class MaxScoreOverRotations:
|
|
372
400
|
rotations_out.flush()
|
373
401
|
scores_out, rotations_out = None, None
|
374
402
|
|
375
|
-
|
403
|
+
results[i] = None
|
376
404
|
scores, rotations = None, None
|
377
405
|
|
378
406
|
if use_memmap:
|
@@ -390,13 +418,166 @@ class MaxScoreOverRotations:
|
|
390
418
|
scores_out,
|
391
419
|
np.zeros(scores_out.ndim, dtype=int),
|
392
420
|
rotations_out,
|
393
|
-
|
421
|
+
cls._invert_rmap(master_rotation_mapping),
|
394
422
|
)
|
395
423
|
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
424
|
+
|
425
|
+
class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
|
426
|
+
"""
|
427
|
+
Implements constrained template matching using rejection sampling.
|
428
|
+
|
429
|
+
Parameters
|
430
|
+
----------
|
431
|
+
cone_angle : float
|
432
|
+
Maximum accepted rotational deviation in degrees.
|
433
|
+
positions : BackendArray
|
434
|
+
Array of shape (n, d) with n seed point translations.
|
435
|
+
rotations : BackendArray
|
436
|
+
Array of shape (n, d, d) with n seed point rotation matrices.
|
437
|
+
reference : BackendArray
|
438
|
+
Reference orientation of the template, wlog defaults to (0,0,1).
|
439
|
+
acceptance_radius : int or tuple of ints
|
440
|
+
Translational acceptance radius around seed point in voxels.
|
441
|
+
**kwargs : dict, optional
|
442
|
+
Keyword arguments passed to the constructor of :py:class:`MaxScoreOverRotations`.
|
443
|
+
"""
|
444
|
+
|
445
|
+
def __init__(
|
446
|
+
self,
|
447
|
+
cone_angle: float,
|
448
|
+
positions: BackendArray,
|
449
|
+
rotations: BackendArray,
|
450
|
+
reference: BackendArray = (0, 0, 1),
|
451
|
+
acceptance_radius: int = 10,
|
452
|
+
**kwargs,
|
453
|
+
):
|
454
|
+
MaxScoreOverRotations.__init__(self, **kwargs)
|
455
|
+
|
456
|
+
if not isinstance(acceptance_radius, (int, Tuple)):
|
457
|
+
raise ValueError("acceptance_radius needs to be of type int or tuple.")
|
458
|
+
|
459
|
+
if isinstance(acceptance_radius, int):
|
460
|
+
acceptance_radius = (
|
461
|
+
acceptance_radius,
|
462
|
+
acceptance_radius,
|
463
|
+
acceptance_radius,
|
464
|
+
)
|
465
|
+
acceptance_radius = tuple(int(x) for x in acceptance_radius)
|
466
|
+
|
467
|
+
self._cone_angle = float(np.radians(cone_angle))
|
468
|
+
self._cone_cutoff = float(np.tan(self._cone_angle))
|
469
|
+
self._reference = be.astype(
|
470
|
+
be.reshape(be.to_backend_array(reference), (-1,)), be._float_dtype
|
471
|
+
)
|
472
|
+
positions = be.astype(be.to_backend_array(positions), be._int_dtype)
|
473
|
+
|
474
|
+
ndim = positions.shape[1]
|
475
|
+
rotate_mask = len(set(acceptance_radius)) != 1
|
476
|
+
extend = max(acceptance_radius)
|
477
|
+
mask = create_mask(
|
478
|
+
mask_type="ellipse",
|
479
|
+
radius=acceptance_radius,
|
480
|
+
shape=tuple(2 * extend + 1 for _ in range(ndim)),
|
481
|
+
center=tuple(extend for _ in range(ndim)),
|
482
|
+
)
|
483
|
+
self._score_mask = be.astype(be.to_backend_array(mask), be._float_dtype)
|
484
|
+
|
485
|
+
# Map position from real space to shifted score space
|
486
|
+
lower_limit = be.to_backend_array(self._offset)
|
487
|
+
positions = be.subtract(positions, lower_limit)
|
488
|
+
positions, valid_positions = cart_to_score(
|
489
|
+
positions=positions,
|
490
|
+
fast_shape=kwargs.get("fast_shape", None),
|
491
|
+
targetshape=kwargs.get("targetshape", None),
|
492
|
+
templateshape=kwargs.get("templateshape", None),
|
493
|
+
fourier_shift=kwargs.get("fourier_shift", None),
|
494
|
+
convolution_mode=kwargs.get("convolution_mode", None),
|
495
|
+
convolution_shape=kwargs.get("convolution_shape", None),
|
496
|
+
)
|
497
|
+
|
498
|
+
self._positions = positions[valid_positions]
|
499
|
+
rotations = be.to_backend_array(rotations)[valid_positions]
|
500
|
+
ex = be.astype(be.to_backend_array((1, 0, 0)), be._float_dtype)
|
501
|
+
ey = be.astype(be.to_backend_array((0, 1, 0)), be._float_dtype)
|
502
|
+
ez = be.astype(be.to_backend_array((0, 0, 1)), be._float_dtype)
|
503
|
+
|
504
|
+
self._normals_x = (rotations @ ex[..., None])[..., 0]
|
505
|
+
self._normals_y = (rotations @ ey[..., None])[..., 0]
|
506
|
+
self._normals_z = (rotations @ ez[..., None])[..., 0]
|
507
|
+
|
508
|
+
# Periodic wrapping could be avoided by padding the target
|
509
|
+
shape = be.to_backend_array(self._shape)
|
510
|
+
starts = be.subtract(self._positions, extend)
|
511
|
+
ret, (n, d), mshape = [], self._positions.shape, self._score_mask.shape
|
512
|
+
if starts.shape[0] > 0:
|
513
|
+
for i in range(d):
|
514
|
+
indices = starts[:, slice(i, i + 1)] + be.arange(mshape[i])[None]
|
515
|
+
indices = be.mod(indices, shape[i], out=indices)
|
516
|
+
indices_shape = (n, *tuple(1 if k != i else -1 for k in range(d)))
|
517
|
+
ret.append(be.reshape(indices, indices_shape))
|
518
|
+
|
519
|
+
self._index_grid = tuple(ret)
|
520
|
+
self._mask_shape = tuple(1 if i != 0 else -1 for i in range(1 + ndim))
|
521
|
+
|
522
|
+
if rotate_mask:
|
523
|
+
self._score_mask = be.zeros(
|
524
|
+
(rotations.shape[0], *self._score_mask.shape), dtype=be._float_dtype
|
525
|
+
)
|
526
|
+
for i in range(rotations.shape[0]):
|
527
|
+
mask = create_mask(
|
528
|
+
mask_type="ellipse",
|
529
|
+
radius=acceptance_radius,
|
530
|
+
shape=tuple(2 * extend + 1 for _ in range(ndim)),
|
531
|
+
center=tuple(extend for _ in range(ndim)),
|
532
|
+
orientation=be.to_numpy_array(rotations[i]),
|
533
|
+
)
|
534
|
+
self._score_mask[i] = be.astype(
|
535
|
+
be.to_backend_array(mask), be._float_dtype
|
536
|
+
)
|
537
|
+
|
538
|
+
def __call__(
|
539
|
+
self, state: Tuple, scores: BackendArray, rotation_matrix: BackendArray
|
540
|
+
) -> Tuple:
|
541
|
+
mask = self._get_constraint(rotation_matrix)
|
542
|
+
mask = self._get_score_mask(mask=mask, scores=scores)
|
543
|
+
|
544
|
+
scores = be.multiply(scores, mask, out=scores)
|
545
|
+
return super().__call__(state, scores=scores, rotation_matrix=rotation_matrix)
|
546
|
+
|
547
|
+
def _get_constraint(self, rotation_matrix: BackendArray) -> BackendArray:
|
548
|
+
"""
|
549
|
+
Determine whether the angle between projection of reference w.r.t.
|
550
|
+
a given rotation matrix and a set of rotations fall within the set
|
551
|
+
cone_angle cutoff.
|
552
|
+
|
553
|
+
Parameters
|
554
|
+
----------
|
555
|
+
rotation_matrix : BackendArray
|
556
|
+
Rotation matrix with shape (d,d).
|
557
|
+
|
558
|
+
Returns
|
559
|
+
-------
|
560
|
+
BackendArray
|
561
|
+
Boolean mask of shape (n, )
|
562
|
+
"""
|
563
|
+
template_rot = rotation_matrix @ self._reference
|
564
|
+
|
565
|
+
x = be.sum(be.multiply(self._normals_x, template_rot), axis=1)
|
566
|
+
y = be.sum(be.multiply(self._normals_y, template_rot), axis=1)
|
567
|
+
z = be.sum(be.multiply(self._normals_z, template_rot), axis=1)
|
568
|
+
|
569
|
+
return be.sqrt(x**2 + y**2) <= (z * self._cone_cutoff)
|
570
|
+
|
571
|
+
def _get_score_mask(self, mask: BackendArray, scores: BackendArray, **kwargs):
|
572
|
+
score_mask = be.zeros(scores.shape, scores.dtype)
|
573
|
+
|
574
|
+
if be.sum(mask) == 0:
|
575
|
+
return score_mask
|
576
|
+
mask = be.reshape(mask, self._mask_shape)
|
577
|
+
|
578
|
+
score_mask = be.addat(score_mask, self._index_grid, self._score_mask * mask)
|
579
|
+
return score_mask > 0
|
580
|
+
|
400
581
|
|
401
582
|
class MaxScoreOverTranslations(MaxScoreOverRotations):
|
402
583
|
"""
|
@@ -436,43 +617,40 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
436
617
|
super().__init__(
|
437
618
|
shape=shape_reduced, shm_handler=shm_handler, offset=offset, **kwargs
|
438
619
|
)
|
439
|
-
|
440
|
-
self.rotations = be.full(1, dtype=be._int_dtype, fill_value=-1)
|
441
|
-
self.rotations = be.to_sharedarr(self.rotations, shm_handler)
|
442
620
|
self._aggregate_axis = aggregate_axis
|
443
621
|
|
444
|
-
def
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
rotation = be.tobytes(rotation_matrix)
|
451
|
-
rotation_index = self.rotation_mapping.setdefault(
|
452
|
-
rotation, rotation_index
|
453
|
-
)
|
454
|
-
max_score = be.max(scores, axis=self._aggregate_axis)
|
455
|
-
self.scores[rotation_index] = max_score
|
456
|
-
return None
|
622
|
+
def init_state(self):
|
623
|
+
scores = be.full(
|
624
|
+
shape=self._shape, dtype=be._float_dtype, fill_value=self._score_threshold
|
625
|
+
)
|
626
|
+
rotations = be.full(1, dtype=be._int_dtype, fill_value=-1)
|
627
|
+
return scores, rotations, {}
|
457
628
|
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
629
|
+
def __call__(
|
630
|
+
self, state, scores: BackendArray, rotation_matrix: BackendArray
|
631
|
+
) -> Tuple:
|
632
|
+
prev_scores, rotations, rotation_mapping = state
|
633
|
+
|
634
|
+
rotation_index = len(rotation_mapping)
|
635
|
+
if self._inversion_mapping:
|
636
|
+
rotation_mapping[rotation_index] = rotation_matrix
|
637
|
+
else:
|
638
|
+
rotation = be.tobytes(rotation_matrix)
|
639
|
+
rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
|
640
|
+
max_score = be.max(scores, axis=self._aggregate_axis)
|
641
|
+
|
642
|
+
update = prev_scores[rotation_index]
|
643
|
+
update = be.maximum(max_score, update, out=update)
|
644
|
+
return prev_scores, rotations, rotation_mapping
|
467
645
|
|
468
646
|
@classmethod
|
469
|
-
def merge(cls,
|
647
|
+
def merge(cls, states: List[Tuple], **kwargs) -> Tuple:
|
470
648
|
"""
|
471
649
|
Merge multiple instances of the current class.
|
472
650
|
|
473
651
|
Parameters
|
474
652
|
----------
|
475
|
-
|
653
|
+
states : list of tuple
|
476
654
|
List of instance's internal state created by applying `result`.
|
477
655
|
**kwargs : dict, optional
|
478
656
|
Optional keyword arguments.
|
@@ -488,31 +666,18 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
488
666
|
Dict
|
489
667
|
Mapping between rotations and rotation indices.
|
490
668
|
"""
|
491
|
-
if len(
|
492
|
-
return
|
669
|
+
if len(states) == 1:
|
670
|
+
return states[0]
|
493
671
|
|
494
672
|
# Determine output array shape and create consistent rotation map
|
495
|
-
|
496
|
-
for i in range(len(param_stores)):
|
497
|
-
if param_stores[i] is None:
|
498
|
-
continue
|
499
|
-
|
500
|
-
scores, offset, rotations, rotation_mapping = param_stores[i]
|
501
|
-
if out_shape is None:
|
502
|
-
out_shape = np.zeros(scores.ndim, int)
|
503
|
-
scores_dtype, rotations_out = scores.dtype, rotations
|
504
|
-
out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
|
505
|
-
|
506
|
-
for key, value in rotation_mapping.items():
|
507
|
-
if key not in new_rotation_mapping:
|
508
|
-
new_rotation_mapping[key] = len(new_rotation_mapping)
|
509
|
-
|
673
|
+
states, master_rotation_mapping, out_shape = cls._harmonize_states(states)
|
510
674
|
if out_shape is None:
|
511
675
|
return None
|
512
676
|
|
513
|
-
out_shape[0] = len(
|
677
|
+
out_shape[0] = len(master_rotation_mapping)
|
514
678
|
out_shape = tuple(int(x) for x in out_shape)
|
515
679
|
|
680
|
+
scores_dtype = states[0][0].dtype
|
516
681
|
use_memmap = kwargs.get("use_memmap", False)
|
517
682
|
if use_memmap:
|
518
683
|
scores_out_filename = generate_tempfile_name()
|
@@ -528,8 +693,8 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
528
693
|
dtype=scores_dtype,
|
529
694
|
)
|
530
695
|
|
531
|
-
for i in range(len(
|
532
|
-
if
|
696
|
+
for i in range(len(states)):
|
697
|
+
if states[i] is None:
|
533
698
|
continue
|
534
699
|
|
535
700
|
if use_memmap:
|
@@ -539,11 +704,11 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
539
704
|
shape=out_shape,
|
540
705
|
dtype=scores_dtype,
|
541
706
|
)
|
542
|
-
scores, offset, rotations, rotation_mapping =
|
707
|
+
scores, offset, rotations, rotation_mapping = states[i]
|
543
708
|
|
544
709
|
outer_table = np.arange(len(rotation_mapping), dtype=int)
|
545
710
|
lookup_table = np.array(
|
546
|
-
[
|
711
|
+
[master_rotation_mapping[key] for key in rotation_mapping.keys()],
|
547
712
|
dtype=int,
|
548
713
|
)
|
549
714
|
|
@@ -559,7 +724,7 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
559
724
|
scores_out.flush()
|
560
725
|
scores_out = None
|
561
726
|
|
562
|
-
|
727
|
+
states[i], scores = None, None
|
563
728
|
|
564
729
|
if use_memmap:
|
565
730
|
scores_out = np.memmap(
|
@@ -569,8 +734,8 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
569
734
|
return (
|
570
735
|
scores_out,
|
571
736
|
np.zeros(scores_out.ndim, dtype=int),
|
572
|
-
|
573
|
-
|
737
|
+
states[2],
|
738
|
+
cls._invert_rmap(master_rotation_mapping),
|
574
739
|
)
|
575
740
|
|
576
741
|
def _postprocess(self, **kwargs):
|