pytme-0.2.0b0-cp311-cp311-macosx_14_0_arm64.whl → pytme-0.2.2-cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.2.data/scripts/match_template.py +1187 -0
- {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
- {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
- pytme-0.2.2.dist-info/METADATA +91 -0
- pytme-0.2.2.dist-info/RECORD +74 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
- scripts/extract_candidates.py +126 -87
- scripts/match_template.py +596 -209
- scripts/match_template_filters.py +571 -223
- scripts/postprocess.py +170 -71
- scripts/preprocessor_gui.py +179 -86
- scripts/refine_matches.py +567 -159
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +627 -855
- tme/backends/__init__.py +41 -11
- tme/backends/_jax_utils.py +185 -0
- tme/backends/cupy_backend.py +120 -225
- tme/backends/jax_backend.py +282 -0
- tme/backends/matching_backend.py +464 -388
- tme/backends/mlx_backend.py +45 -68
- tme/backends/npfftw_backend.py +256 -514
- tme/backends/pytorch_backend.py +41 -154
- tme/density.py +312 -421
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +366 -303
- tme/matching_exhaustive.py +279 -1521
- tme/matching_optimization.py +234 -129
- tme/matching_scores.py +884 -0
- tme/matching_utils.py +281 -387
- tme/memory.py +377 -0
- tme/orientations.py +226 -66
- tme/parser.py +3 -4
- tme/preprocessing/__init__.py +2 -0
- tme/preprocessing/_utils.py +217 -0
- tme/preprocessing/composable_filter.py +31 -0
- tme/preprocessing/compose.py +55 -0
- tme/preprocessing/frequency_filters.py +388 -0
- tme/preprocessing/tilt_series.py +1011 -0
- tme/preprocessor.py +574 -530
- tme/structure.py +495 -189
- tme/types.py +5 -3
- pytme-0.2.0b0.data/scripts/match_template.py +0 -800
- pytme-0.2.0b0.dist-info/METADATA +0 -73
- pytme-0.2.0b0.dist-info/RECORD +0 -66
- tme/helpers.py +0 -881
- tme/matching_constrained.py +0 -195
- {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
tme/matching_exhaustive.py
CHANGED
@@ -4,1357 +4,119 @@
|
|
4
4
|
|
5
5
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
6
|
"""
|
7
|
-
import os
|
8
7
|
import sys
|
9
8
|
import warnings
|
10
|
-
|
11
|
-
from itertools import product
|
12
|
-
from typing import Callable, Tuple, Dict
|
9
|
+
import traceback
|
13
10
|
from functools import wraps
|
11
|
+
from itertools import product
|
12
|
+
from typing import Callable, Tuple, Dict, Optional
|
13
|
+
|
14
14
|
from joblib import Parallel, delayed
|
15
15
|
from multiprocessing.managers import SharedMemoryManager
|
16
16
|
|
17
|
-
import
|
18
|
-
from
|
19
|
-
|
20
|
-
from .
|
21
|
-
from .
|
22
|
-
|
23
|
-
handle_traceback,
|
24
|
-
split_numpy_array_slices,
|
25
|
-
conditional_execute,
|
26
|
-
)
|
27
|
-
from .matching_memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
|
28
|
-
# from .preprocessing import Compose
|
29
|
-
from .preprocessor import Preprocessor
|
30
|
-
from .matching_data import MatchingData
|
31
|
-
from .backends import backend
|
32
|
-
from .types import NDArray, CallbackClass
|
33
|
-
|
34
|
-
os.environ["MKL_NUM_THREADS"] = "1"
|
35
|
-
os.environ["OMP_NUM_THREADS"] = "1"
|
36
|
-
os.environ["PYFFTW_NUM_THREADS"] = "1"
|
37
|
-
os.environ["OPENBLAS_NUM_THREADS"] = "1"
|
38
|
-
|
39
|
-
|
40
|
-
def _run_inner(backend_name, backend_args, **kwargs):
|
41
|
-
from tme.backends import backend
|
17
|
+
from .backends import backend as be
|
18
|
+
from .preprocessing import Compose
|
19
|
+
from .matching_utils import split_shape
|
20
|
+
from .matching_scores import MATCHING_EXHAUSTIVE_REGISTER
|
21
|
+
from .types import CallbackClass, MatchingData
|
22
|
+
from .memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
|
42
23
|
|
43
|
-
backend.change_backend(backend_name, **backend_args)
|
44
|
-
return scan(**kwargs)
|
45
24
|
|
46
|
-
|
47
|
-
def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
|
25
|
+
def _handle_traceback(last_type, last_value, last_traceback):
|
48
26
|
"""
|
49
|
-
|
50
|
-
standard deviation based on the elements in mask. Subsequently, the template is
|
51
|
-
multiplied by the mask.
|
27
|
+
Handle sys.exc_info().
|
52
28
|
|
53
29
|
Parameters
|
54
30
|
----------
|
55
|
-
|
56
|
-
The
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
Mask intensity used to compute expectations.
|
62
|
-
|
63
|
-
References
|
64
|
-
----------
|
65
|
-
.. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
66
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
67
|
-
.. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
|
68
|
-
R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
|
69
|
-
13375 (2023)
|
70
|
-
|
71
|
-
Returns
|
72
|
-
-------
|
73
|
-
None
|
74
|
-
This function modifies `template` in-place and does not return any value.
|
75
|
-
"""
|
76
|
-
masked_mean = backend.sum(backend.multiply(template, mask))
|
77
|
-
masked_mean = backend.divide(masked_mean, mask_intensity)
|
78
|
-
masked_std = backend.sum(backend.multiply(backend.square(template), mask))
|
79
|
-
masked_std = backend.subtract(
|
80
|
-
masked_std / mask_intensity, backend.square(masked_mean)
|
81
|
-
)
|
82
|
-
masked_std = backend.sqrt(backend.maximum(masked_std, 0))
|
83
|
-
|
84
|
-
backend.subtract(template, masked_mean, out=template)
|
85
|
-
backend.divide(template, masked_std, out=template)
|
86
|
-
backend.multiply(template, mask, out=template)
|
87
|
-
return None
|
88
|
-
|
89
|
-
|
90
|
-
def apply_filter(ft_template, template_filter):
|
91
|
-
# This is an approximation to applying the mask, irfftn, normalize, rfftn
|
92
|
-
std_before = backend.std(ft_template)
|
93
|
-
backend.multiply(ft_template, template_filter, out=ft_template)
|
94
|
-
backend.multiply(
|
95
|
-
ft_template, std_before / backend.std(ft_template), out=ft_template
|
96
|
-
)
|
97
|
-
|
98
|
-
|
99
|
-
def cc_setup(
|
100
|
-
rfftn: Callable,
|
101
|
-
irfftn: Callable,
|
102
|
-
template: NDArray,
|
103
|
-
target: NDArray,
|
104
|
-
fast_shape: Tuple[int],
|
105
|
-
fast_ft_shape: Tuple[int],
|
106
|
-
real_dtype: type,
|
107
|
-
complex_dtype: type,
|
108
|
-
shared_memory_handler: Callable,
|
109
|
-
callback_class: Callable,
|
110
|
-
callback_class_args: Dict,
|
111
|
-
**kwargs,
|
112
|
-
) -> Dict:
|
113
|
-
"""
|
114
|
-
Setup to compute the cross-correlation between a target f and template g
|
115
|
-
defined as:
|
116
|
-
|
117
|
-
.. math::
|
118
|
-
|
119
|
-
\\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
120
|
-
|
121
|
-
|
122
|
-
See Also
|
123
|
-
--------
|
124
|
-
:py:meth:`corr_scoring`
|
125
|
-
:py:class:`tme.matching_optimization.CrossCorrelation`
|
126
|
-
"""
|
127
|
-
target_shape = target.shape
|
128
|
-
target_pad = backend.topleft_pad(target, fast_shape)
|
129
|
-
target_pad_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
130
|
-
|
131
|
-
rfftn(target_pad, target_pad_ft)
|
132
|
-
|
133
|
-
target_ft_out = backend.arr_to_sharedarr(
|
134
|
-
arr=target_pad_ft, shared_memory_handler=shared_memory_handler
|
135
|
-
)
|
136
|
-
|
137
|
-
template_out = backend.arr_to_sharedarr(
|
138
|
-
arr=template, shared_memory_handler=shared_memory_handler
|
139
|
-
)
|
140
|
-
inv_denominator_buffer = backend.arr_to_sharedarr(
|
141
|
-
arr=backend.preallocate_array(1, real_dtype) + 1,
|
142
|
-
shared_memory_handler=shared_memory_handler,
|
143
|
-
)
|
144
|
-
numerator2_buffer = backend.arr_to_sharedarr(
|
145
|
-
arr=backend.preallocate_array(1, real_dtype),
|
146
|
-
shared_memory_handler=shared_memory_handler,
|
147
|
-
)
|
148
|
-
|
149
|
-
target_ft_tuple = (target_ft_out, fast_ft_shape, complex_dtype)
|
150
|
-
template_tuple = (template_out, template.shape, real_dtype)
|
151
|
-
|
152
|
-
inv_denominator_tuple = (inv_denominator_buffer, (1,), real_dtype)
|
153
|
-
numerator2_tuple = (numerator2_buffer, (1,), real_dtype)
|
154
|
-
|
155
|
-
ret = {
|
156
|
-
"template": template_tuple,
|
157
|
-
"ft_target": target_ft_tuple,
|
158
|
-
"inv_denominator": inv_denominator_tuple,
|
159
|
-
"numerator2": numerator2_tuple,
|
160
|
-
"targetshape": target_shape,
|
161
|
-
"templateshape": template.shape,
|
162
|
-
"fast_shape": fast_shape,
|
163
|
-
"fast_ft_shape": fast_ft_shape,
|
164
|
-
"real_dtype": real_dtype,
|
165
|
-
"complex_dtype": complex_dtype,
|
166
|
-
"callback_class": callback_class,
|
167
|
-
"callback_class_args": callback_class_args,
|
168
|
-
"use_memmap": kwargs.get("use_memmap", False),
|
169
|
-
"temp_dir": kwargs.get("temp_dir", None),
|
170
|
-
}
|
171
|
-
|
172
|
-
return ret
|
173
|
-
|
174
|
-
|
175
|
-
def lcc_setup(**kwargs) -> Dict:
|
176
|
-
"""
|
177
|
-
Setup to compute the cross-correlation between a laplace transformed target f
|
178
|
-
and laplace transformed template g defined as:
|
179
|
-
|
180
|
-
.. math::
|
181
|
-
|
182
|
-
\\mathcal{F}^{-1}(\\mathcal{F}(\\nabla^{2}f) \\cdot \\mathcal{F}(\\nabla^{2} g)^*)
|
183
|
-
|
184
|
-
|
185
|
-
See Also
|
186
|
-
--------
|
187
|
-
:py:meth:`corr_scoring`
|
188
|
-
:py:class:`tme.matching_optimization.LaplaceCrossCorrelation`
|
189
|
-
"""
|
190
|
-
kwargs["target"] = laplace(kwargs["target"], mode="wrap")
|
191
|
-
kwargs["template"] = laplace(kwargs["template"], mode="wrap")
|
192
|
-
return cc_setup(**kwargs)
|
193
|
-
|
194
|
-
|
195
|
-
def corr_setup(
|
196
|
-
rfftn: Callable,
|
197
|
-
irfftn: Callable,
|
198
|
-
template: NDArray,
|
199
|
-
template_mask: NDArray,
|
200
|
-
target: NDArray,
|
201
|
-
fast_shape: Tuple[int],
|
202
|
-
fast_ft_shape: Tuple[int],
|
203
|
-
real_dtype: type,
|
204
|
-
complex_dtype: type,
|
205
|
-
shared_memory_handler: Callable,
|
206
|
-
callback_class: Callable,
|
207
|
-
callback_class_args: Dict,
|
208
|
-
**kwargs,
|
209
|
-
) -> Dict:
|
210
|
-
"""
|
211
|
-
Setup to compute a normalized cross-correlation between a target f and a template g.
|
212
|
-
|
213
|
-
.. math::
|
214
|
-
|
215
|
-
\\frac{CC(f,g) - \\overline{g} \\cdot CC(f, m)}
|
216
|
-
{(CC(f^2, m) - \\frac{CC(f, m)^2}{N_g}) \\cdot \\sigma_{g}}
|
217
|
-
|
218
|
-
Where:
|
219
|
-
|
220
|
-
.. math::
|
221
|
-
|
222
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
223
|
-
|
224
|
-
and m is a mask with the same dimension as the template filled with ones.
|
225
|
-
|
226
|
-
References
|
227
|
-
----------
|
228
|
-
.. [1] J. P. Lewis, "Fast Normalized Cross-Correlation", Industrial Light
|
229
|
-
and Magic.
|
230
|
-
|
231
|
-
See Also
|
232
|
-
--------
|
233
|
-
:py:meth:`corr_scoring`
|
234
|
-
:py:class:`tme.matching_optimization.NormalizedCrossCorrelation`.
|
235
|
-
"""
|
236
|
-
target_pad = backend.topleft_pad(target, fast_shape)
|
237
|
-
|
238
|
-
# The exact composition of the denominator is debatable
|
239
|
-
# scikit-image match_template multiplies the running sum of the target
|
240
|
-
# with a scaling factor derived from the template. This is probably appropriate
|
241
|
-
# in pattern matching situations where the template exists in the target
|
242
|
-
window_template = backend.topleft_pad(template_mask, fast_shape)
|
243
|
-
ft_window_template = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
244
|
-
rfftn(window_template, ft_window_template)
|
245
|
-
window_template = None
|
246
|
-
|
247
|
-
# Target and squared target window sums
|
248
|
-
ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
249
|
-
ft_target2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
250
|
-
denominator = backend.preallocate_array(fast_shape, real_dtype)
|
251
|
-
target_window_sum = backend.preallocate_array(fast_shape, real_dtype)
|
252
|
-
rfftn(target_pad, ft_target)
|
253
|
-
|
254
|
-
rfftn(backend.square(target_pad), ft_target2)
|
255
|
-
backend.multiply(ft_target2, ft_window_template, out=ft_target2)
|
256
|
-
irfftn(ft_target2, denominator)
|
257
|
-
|
258
|
-
backend.multiply(ft_target, ft_window_template, out=ft_window_template)
|
259
|
-
irfftn(ft_window_template, target_window_sum)
|
260
|
-
|
261
|
-
target_pad, ft_target2, ft_window_template = None, None, None
|
262
|
-
|
263
|
-
# Normalizing constants
|
264
|
-
n_observations = backend.sum(template_mask)
|
265
|
-
template_mean = backend.sum(backend.multiply(template, template_mask))
|
266
|
-
template_mean = backend.divide(template_mean, n_observations)
|
267
|
-
template_ssd = backend.sum(
|
268
|
-
backend.square(
|
269
|
-
backend.multiply(backend.multiply(template, template_mean), template_mask)
|
270
|
-
)
|
271
|
-
)
|
272
|
-
template_volume = np.prod(template.shape)
|
273
|
-
backend.multiply(template, template_mask, out=template)
|
274
|
-
|
275
|
-
# Final numerator is score - numerator2
|
276
|
-
numerator2 = backend.multiply(target_window_sum, template_mean)
|
277
|
-
|
278
|
-
# Compute denominator
|
279
|
-
backend.multiply(target_window_sum, target_window_sum, out=target_window_sum)
|
280
|
-
backend.divide(target_window_sum, template_volume, out=target_window_sum)
|
281
|
-
|
282
|
-
backend.subtract(denominator, target_window_sum, out=denominator)
|
283
|
-
backend.multiply(denominator, template_ssd, out=denominator)
|
284
|
-
backend.maximum(denominator, 0, out=denominator)
|
285
|
-
backend.sqrt(denominator, out=denominator)
|
286
|
-
target_window_sum = None
|
287
|
-
|
288
|
-
# Invert denominator to compute final score as product
|
289
|
-
denominator_mask = denominator > backend.eps(denominator.dtype)
|
290
|
-
inv_denominator = backend.preallocate_array(fast_shape, real_dtype)
|
291
|
-
inv_denominator[denominator_mask] = 1 / denominator[denominator_mask]
|
292
|
-
|
293
|
-
# Convert arrays used in subsequent fitting to SharedMemory objects
|
294
|
-
template_buffer = backend.arr_to_sharedarr(
|
295
|
-
arr=template, shared_memory_handler=shared_memory_handler
|
296
|
-
)
|
297
|
-
target_ft_buffer = backend.arr_to_sharedarr(
|
298
|
-
arr=ft_target, shared_memory_handler=shared_memory_handler
|
299
|
-
)
|
300
|
-
inv_denominator_buffer = backend.arr_to_sharedarr(
|
301
|
-
arr=inv_denominator, shared_memory_handler=shared_memory_handler
|
302
|
-
)
|
303
|
-
numerator2_buffer = backend.arr_to_sharedarr(
|
304
|
-
arr=numerator2, shared_memory_handler=shared_memory_handler
|
305
|
-
)
|
306
|
-
|
307
|
-
template_tuple = (template_buffer, deepcopy(template.shape), real_dtype)
|
308
|
-
target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
|
309
|
-
|
310
|
-
inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
|
311
|
-
numerator2_tuple = (numerator2_buffer, fast_shape, real_dtype)
|
312
|
-
|
313
|
-
ft_target, inv_denominator, numerator2 = None, None, None
|
314
|
-
|
315
|
-
ret = {
|
316
|
-
"template": template_tuple,
|
317
|
-
"ft_target": target_ft_tuple,
|
318
|
-
"inv_denominator": inv_denominator_tuple,
|
319
|
-
"numerator2": numerator2_tuple,
|
320
|
-
"targetshape": deepcopy(target.shape),
|
321
|
-
"templateshape": deepcopy(template.shape),
|
322
|
-
"fast_shape": fast_shape,
|
323
|
-
"fast_ft_shape": fast_ft_shape,
|
324
|
-
"real_dtype": real_dtype,
|
325
|
-
"complex_dtype": complex_dtype,
|
326
|
-
"callback_class": callback_class,
|
327
|
-
"callback_class_args": callback_class_args,
|
328
|
-
"template_mean": kwargs.get("template_mean", template_mean),
|
329
|
-
}
|
330
|
-
|
331
|
-
return ret
|
332
|
-
|
333
|
-
|
334
|
-
def cam_setup(**kwargs):
|
335
|
-
"""
|
336
|
-
Setup to compute a normalized cross-correlation between a target f and a template g
|
337
|
-
over their means. In practice this can be expressed like the cross-correlation
|
338
|
-
CORR defined in :py:meth:`corr_scoring`, so that:
|
339
|
-
|
340
|
-
Notes
|
341
|
-
-----
|
342
|
-
|
343
|
-
.. math::
|
344
|
-
|
345
|
-
\\text{CORR}(f-\\overline{f}, g - \\overline{g})
|
346
|
-
|
347
|
-
Where
|
348
|
-
|
349
|
-
.. math::
|
350
|
-
|
351
|
-
\\overline{f}, \\overline{g}
|
352
|
-
|
353
|
-
are the mean of f and g respectively.
|
354
|
-
|
355
|
-
References
|
356
|
-
----------
|
357
|
-
.. [1] J. P. Lewis, "Fast Normalized Cross-Correlation", Industrial Light
|
358
|
-
and Magic.
|
359
|
-
|
360
|
-
See Also
|
361
|
-
--------
|
362
|
-
:py:meth:`corr_scoring`.
|
363
|
-
:py:class:`tme.matching_optimization.NormalizedCrossCorrelationMean`.
|
364
|
-
"""
|
365
|
-
kwargs["template"] = kwargs["template"] - kwargs["template"].mean()
|
366
|
-
return corr_setup(**kwargs)
|
367
|
-
|
368
|
-
|
369
|
-
def flc_setup(
|
370
|
-
rfftn: Callable,
|
371
|
-
irfftn: Callable,
|
372
|
-
template: NDArray,
|
373
|
-
template_mask: NDArray,
|
374
|
-
target: NDArray,
|
375
|
-
fast_shape: Tuple[int],
|
376
|
-
fast_ft_shape: Tuple[int],
|
377
|
-
real_dtype: type,
|
378
|
-
complex_dtype: type,
|
379
|
-
shared_memory_handler: Callable,
|
380
|
-
callback_class: Callable,
|
381
|
-
callback_class_args: Dict,
|
382
|
-
**kwargs,
|
383
|
-
) -> Dict:
|
384
|
-
"""
|
385
|
-
Setup to compute a normalized cross-correlation score of a target f a template g
|
386
|
-
and a mask m:
|
387
|
-
|
388
|
-
.. math::
|
389
|
-
|
390
|
-
\\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
|
391
|
-
{N_m * \\sqrt{
|
392
|
-
\\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
|
393
|
-
}
|
394
|
-
|
395
|
-
Where:
|
396
|
-
|
397
|
-
.. math::
|
398
|
-
|
399
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
400
|
-
|
401
|
-
and Nm is the number of voxels within the template mask m.
|
402
|
-
|
403
|
-
References
|
404
|
-
----------
|
405
|
-
.. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
|
406
|
-
Microsc. Microanal. 26, 2516 (2020)
|
407
|
-
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
408
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
409
|
-
|
410
|
-
See Also
|
411
|
-
--------
|
412
|
-
:py:meth:`flc_scoring`
|
413
|
-
"""
|
414
|
-
target_pad = backend.topleft_pad(target, fast_shape)
|
415
|
-
|
416
|
-
# Target and squared target window sums
|
417
|
-
ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
418
|
-
ft_target2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
419
|
-
rfftn(target_pad, ft_target)
|
420
|
-
backend.square(target_pad, out=target_pad)
|
421
|
-
rfftn(target_pad, ft_target2)
|
422
|
-
|
423
|
-
# Convert arrays used in subsequent fitting to SharedMemory objects
|
424
|
-
ft_target = backend.arr_to_sharedarr(
|
425
|
-
arr=ft_target, shared_memory_handler=shared_memory_handler
|
426
|
-
)
|
427
|
-
ft_target2 = backend.arr_to_sharedarr(
|
428
|
-
arr=ft_target2, shared_memory_handler=shared_memory_handler
|
429
|
-
)
|
430
|
-
|
431
|
-
normalize_under_mask(
|
432
|
-
template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
|
433
|
-
)
|
434
|
-
|
435
|
-
template_buffer = backend.arr_to_sharedarr(
|
436
|
-
arr=template, shared_memory_handler=shared_memory_handler
|
437
|
-
)
|
438
|
-
template_mask_buffer = backend.arr_to_sharedarr(
|
439
|
-
arr=template_mask, shared_memory_handler=shared_memory_handler
|
440
|
-
)
|
441
|
-
|
442
|
-
template_tuple = (template_buffer, template.shape, real_dtype)
|
443
|
-
template_mask_tuple = (template_mask_buffer, template_mask.shape, real_dtype)
|
444
|
-
|
445
|
-
target_ft_tuple = (ft_target, fast_ft_shape, complex_dtype)
|
446
|
-
target_ft2_tuple = (ft_target2, fast_ft_shape, complex_dtype)
|
447
|
-
|
448
|
-
ret = {
|
449
|
-
"template": template_tuple,
|
450
|
-
"template_mask": template_mask_tuple,
|
451
|
-
"ft_target": target_ft_tuple,
|
452
|
-
"ft_target2": target_ft2_tuple,
|
453
|
-
"targetshape": target.shape,
|
454
|
-
"templateshape": template.shape,
|
455
|
-
"fast_shape": fast_shape,
|
456
|
-
"fast_ft_shape": fast_ft_shape,
|
457
|
-
"real_dtype": real_dtype,
|
458
|
-
"complex_dtype": complex_dtype,
|
459
|
-
"callback_class": callback_class,
|
460
|
-
"callback_class_args": callback_class_args,
|
461
|
-
}
|
462
|
-
|
463
|
-
return ret
|
464
|
-
|
465
|
-
|
466
|
-
def flcSphericalMask_setup(
|
467
|
-
rfftn: Callable,
|
468
|
-
irfftn: Callable,
|
469
|
-
template: NDArray,
|
470
|
-
template_mask: NDArray,
|
471
|
-
target: NDArray,
|
472
|
-
fast_shape: Tuple[int],
|
473
|
-
fast_ft_shape: Tuple[int],
|
474
|
-
real_dtype: type,
|
475
|
-
complex_dtype: type,
|
476
|
-
shared_memory_handler: Callable,
|
477
|
-
callback_class: Callable,
|
478
|
-
callback_class_args: Dict,
|
479
|
-
**kwargs,
|
480
|
-
) -> Dict:
|
481
|
-
"""
|
482
|
-
Like :py:meth:`flc_setup` but for rotation invariant masks. In such cases
|
483
|
-
the score can be computed quicker by not computing the fourier transforms
|
484
|
-
of the mask for each rotation.
|
485
|
-
|
486
|
-
.. math::
|
487
|
-
|
488
|
-
\\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
|
489
|
-
{N_m * \\sqrt{
|
490
|
-
\\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
|
491
|
-
}
|
492
|
-
|
493
|
-
Where:
|
494
|
-
|
495
|
-
.. math::
|
496
|
-
|
497
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
498
|
-
|
499
|
-
and Nm is the number of voxels within the template mask m.
|
500
|
-
|
501
|
-
References
|
502
|
-
----------
|
503
|
-
.. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
|
504
|
-
Microsc. Microanal. 26, 2516 (2020)
|
505
|
-
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
506
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
507
|
-
|
508
|
-
See Also
|
509
|
-
--------
|
510
|
-
:py:meth:`corr_scoring`
|
511
|
-
"""
|
512
|
-
target_pad = backend.topleft_pad(target, fast_shape)
|
513
|
-
|
514
|
-
# Target and squared target window sums
|
515
|
-
ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
516
|
-
ft_template_mask = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
517
|
-
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
518
|
-
|
519
|
-
temp = backend.preallocate_array(fast_shape, real_dtype)
|
520
|
-
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
521
|
-
numerator2 = backend.preallocate_array(1, real_dtype)
|
522
|
-
|
523
|
-
eps = backend.eps(real_dtype)
|
524
|
-
n_observations = backend.sum(template_mask)
|
525
|
-
|
526
|
-
template_mask_pad = backend.topleft_pad(template_mask, fast_shape)
|
527
|
-
rfftn(template_mask_pad, ft_template_mask)
|
528
|
-
|
529
|
-
# Denominator E(X^2) - E(X)^2
|
530
|
-
rfftn(backend.square(target_pad), ft_target)
|
531
|
-
backend.multiply(ft_target, ft_template_mask, out=ft_temp)
|
532
|
-
irfftn(ft_temp, temp2)
|
533
|
-
backend.divide(temp2, n_observations, out=temp2)
|
534
|
-
|
535
|
-
rfftn(target_pad, ft_target)
|
536
|
-
backend.multiply(ft_target, ft_template_mask, out=ft_temp)
|
537
|
-
irfftn(ft_temp, temp)
|
538
|
-
backend.divide(temp, n_observations, out=temp)
|
539
|
-
backend.square(temp, out=temp)
|
540
|
-
|
541
|
-
backend.subtract(temp2, temp, out=temp)
|
542
|
-
|
543
|
-
backend.maximum(temp, 0.0, out=temp)
|
544
|
-
backend.sqrt(temp, out=temp)
|
545
|
-
backend.multiply(temp, n_observations, out=temp)
|
546
|
-
|
547
|
-
tol = 1e3 * eps * backend.max(backend.abs(temp))
|
548
|
-
nonzero_indices = temp > tol
|
549
|
-
|
550
|
-
backend.fill(temp2, 0)
|
551
|
-
temp2[nonzero_indices] = 1 / temp[nonzero_indices]
|
552
|
-
|
553
|
-
normalize_under_mask(
|
554
|
-
template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
|
555
|
-
)
|
556
|
-
|
557
|
-
template_buffer = backend.arr_to_sharedarr(
|
558
|
-
arr=template, shared_memory_handler=shared_memory_handler
|
559
|
-
)
|
560
|
-
template_mask_buffer = backend.arr_to_sharedarr(
|
561
|
-
arr=template_mask, shared_memory_handler=shared_memory_handler
|
562
|
-
)
|
563
|
-
target_ft_buffer = backend.arr_to_sharedarr(
|
564
|
-
arr=ft_target, shared_memory_handler=shared_memory_handler
|
565
|
-
)
|
566
|
-
inv_denominator_buffer = backend.arr_to_sharedarr(
|
567
|
-
arr=temp2, shared_memory_handler=shared_memory_handler
|
568
|
-
)
|
569
|
-
numerator2_buffer = backend.arr_to_sharedarr(
|
570
|
-
arr=numerator2, shared_memory_handler=shared_memory_handler
|
571
|
-
)
|
572
|
-
|
573
|
-
template_tuple = (template_buffer, template.shape, real_dtype)
|
574
|
-
template_mask_tuple = (template_mask_buffer, template.shape, real_dtype)
|
575
|
-
target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
|
576
|
-
|
577
|
-
inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
|
578
|
-
numerator2_tuple = (numerator2_buffer, (1,), real_dtype)
|
579
|
-
|
580
|
-
ret = {
|
581
|
-
"template": template_tuple,
|
582
|
-
"template_mask": template_mask_tuple,
|
583
|
-
"ft_target": target_ft_tuple,
|
584
|
-
"inv_denominator": inv_denominator_tuple,
|
585
|
-
"numerator2": numerator2_tuple,
|
586
|
-
"targetshape": target.shape,
|
587
|
-
"templateshape": template.shape,
|
588
|
-
"fast_shape": fast_shape,
|
589
|
-
"fast_ft_shape": fast_ft_shape,
|
590
|
-
"real_dtype": real_dtype,
|
591
|
-
"complex_dtype": complex_dtype,
|
592
|
-
"callback_class": callback_class,
|
593
|
-
"callback_class_args": callback_class_args,
|
594
|
-
}
|
595
|
-
|
596
|
-
return ret
|
597
|
-
|
598
|
-
|
599
|
-
def mcc_setup(
|
600
|
-
rfftn: Callable,
|
601
|
-
irfftn: Callable,
|
602
|
-
template: NDArray,
|
603
|
-
template_mask: NDArray,
|
604
|
-
target: NDArray,
|
605
|
-
target_mask: NDArray,
|
606
|
-
fast_shape: Tuple[int],
|
607
|
-
fast_ft_shape: Tuple[int],
|
608
|
-
real_dtype: type,
|
609
|
-
complex_dtype: type,
|
610
|
-
shared_memory_handler: Callable,
|
611
|
-
callback_class: Callable,
|
612
|
-
callback_class_args: Dict,
|
613
|
-
**kwargs,
|
614
|
-
) -> Dict:
|
615
|
-
"""
|
616
|
-
Setup to compute a normalized cross-correlation score with masks
|
617
|
-
for both template and target.
|
618
|
-
|
619
|
-
.. math::
|
620
|
-
|
621
|
-
\\frac{
|
622
|
-
CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
|
623
|
-
}{
|
624
|
-
\\sqrt{
|
625
|
-
(CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
|
626
|
-
(CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
|
627
|
-
}
|
628
|
-
}
|
629
|
-
|
630
|
-
Where:
|
631
|
-
|
632
|
-
.. math::
|
633
|
-
|
634
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
635
|
-
|
636
|
-
|
637
|
-
References
|
638
|
-
----------
|
639
|
-
.. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
|
640
|
-
|
641
|
-
See Also
|
642
|
-
--------
|
643
|
-
:py:meth:`mcc_scoring`
|
644
|
-
:py:class:`tme.matching_optimization.MaskedCrossCorrelation`
|
645
|
-
"""
|
646
|
-
target = backend.multiply(target, target_mask > 0, out=target)
|
647
|
-
|
648
|
-
target_pad = backend.topleft_pad(target, fast_shape)
|
649
|
-
target_mask_pad = backend.topleft_pad(target_mask, fast_shape)
|
650
|
-
|
651
|
-
target_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
652
|
-
rfftn(target_pad, target_ft)
|
653
|
-
target_ft_buffer = backend.arr_to_sharedarr(
|
654
|
-
arr=target_ft, shared_memory_handler=shared_memory_handler
|
655
|
-
)
|
656
|
-
|
657
|
-
target_ft2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
658
|
-
rfftn(backend.square(target_pad), target_ft2)
|
659
|
-
target_ft2_buffer = backend.arr_to_sharedarr(
|
660
|
-
arr=target_ft2, shared_memory_handler=shared_memory_handler
|
661
|
-
)
|
31
|
+
last_type : type
|
32
|
+
The type of the last exception.
|
33
|
+
last_value :
|
34
|
+
The value of the last exception.
|
35
|
+
last_traceback : traceback
|
36
|
+
Traceback call stack at the point where the Exception occured.
|
662
37
|
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
)
|
668
|
-
|
669
|
-
template_buffer = backend.arr_to_sharedarr(
|
670
|
-
arr=template, shared_memory_handler=shared_memory_handler
|
671
|
-
)
|
672
|
-
template_mask_buffer = backend.arr_to_sharedarr(
|
673
|
-
arr=template_mask, shared_memory_handler=shared_memory_handler
|
674
|
-
)
|
675
|
-
|
676
|
-
template_tuple = (template_buffer, template.shape, real_dtype)
|
677
|
-
template_mask_tuple = (template_mask_buffer, template.shape, real_dtype)
|
678
|
-
|
679
|
-
target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
|
680
|
-
target_ft2_tuple = (target_ft2_buffer, fast_ft_shape, complex_dtype)
|
681
|
-
target_mask_ft_tuple = (target_mask_ft_buffer, fast_ft_shape, complex_dtype)
|
682
|
-
|
683
|
-
ret = {
|
684
|
-
"template": template_tuple,
|
685
|
-
"template_mask": template_mask_tuple,
|
686
|
-
"ft_target": target_ft_tuple,
|
687
|
-
"ft_target2": target_ft2_tuple,
|
688
|
-
"ft_target_mask": target_mask_ft_tuple,
|
689
|
-
"targetshape": target.shape,
|
690
|
-
"templateshape": template.shape,
|
691
|
-
"fast_shape": fast_shape,
|
692
|
-
"fast_ft_shape": fast_ft_shape,
|
693
|
-
"real_dtype": real_dtype,
|
694
|
-
"complex_dtype": complex_dtype,
|
695
|
-
"callback_class": callback_class,
|
696
|
-
"callback_class_args": callback_class_args,
|
697
|
-
}
|
698
|
-
|
699
|
-
return ret
|
700
|
-
|
701
|
-
|
702
|
-
def corr_scoring(
|
703
|
-
template: Tuple[type, Tuple[int], type],
|
704
|
-
ft_target: Tuple[type, Tuple[int], type],
|
705
|
-
inv_denominator: Tuple[type, Tuple[int], type],
|
706
|
-
numerator2: Tuple[type, Tuple[int], type],
|
707
|
-
template_filter: Tuple[type, Tuple[int], type],
|
708
|
-
targetshape: Tuple[int],
|
709
|
-
templateshape: Tuple[int],
|
710
|
-
fast_shape: Tuple[int],
|
711
|
-
fast_ft_shape: Tuple[int],
|
712
|
-
rotations: NDArray,
|
713
|
-
real_dtype: type,
|
714
|
-
complex_dtype: type,
|
715
|
-
callback_class: CallbackClass,
|
716
|
-
callback_class_args: Dict,
|
717
|
-
interpolation_order: int,
|
718
|
-
convolution_mode: str = "full",
|
719
|
-
**kwargs,
|
720
|
-
) -> CallbackClass:
|
721
|
-
"""
|
722
|
-
Calculates a normalized cross-correlation between a target f and a template g.
|
723
|
-
|
724
|
-
.. math::
|
725
|
-
|
726
|
-
(CC(f,g) - numerator2) \\cdot inv\\_denominator
|
727
|
-
|
728
|
-
Parameters
|
729
|
-
----------
|
730
|
-
template : Tuple[type, Tuple[int], type]
|
731
|
-
Tuple containing a pointer to the template data, its shape, and its datatype.
|
732
|
-
ft_target : Tuple[type, Tuple[int], type]
|
733
|
-
Tuple containing a pointer to the fourier tranform of the target,
|
734
|
-
its shape, and its datatype.
|
735
|
-
inv_denominator : Tuple[type, Tuple[int], type]
|
736
|
-
Tuple containing a pointer to the inverse denominator data, its shape, and its
|
737
|
-
datatype.
|
738
|
-
numerator2 : Tuple[type, Tuple[int], type]
|
739
|
-
Tuple containing a pointer to the numerator2 data, its shape, and its datatype.
|
740
|
-
fast_shape : Tuple[int]
|
741
|
-
The shape for fast Fourier transform.
|
742
|
-
fast_ft_shape : Tuple[int]
|
743
|
-
The shape for fast Fourier transform of the target.
|
744
|
-
rotations : NDArray
|
745
|
-
Array containing the rotation matrices to be applied on the template.
|
746
|
-
real_dtype : type
|
747
|
-
Data type for the real part of the array.
|
748
|
-
complex_dtype : type
|
749
|
-
Data type for the complex part of the array.
|
750
|
-
callback_class : CallbackClass
|
751
|
-
A callable class or function for processing the results after each
|
752
|
-
rotation.
|
753
|
-
callback_class_args : Dict
|
754
|
-
Dictionary of arguments to be passed to the callback class if it's
|
755
|
-
instantiable.
|
756
|
-
interpolation_order : int
|
757
|
-
The order of interpolation to be used while rotating the template.
|
758
|
-
**kwargs :
|
759
|
-
Additional arguments to be passed to the function.
|
760
|
-
|
761
|
-
Returns
|
762
|
-
-------
|
763
|
-
CallbackClass
|
764
|
-
If callback_class was provided an instance of callback_class otherwise None.
|
765
|
-
|
766
|
-
See Also
|
767
|
-
--------
|
768
|
-
:py:meth:`cc_setup`
|
769
|
-
:py:meth:`corr_setup`
|
770
|
-
:py:meth:`cam_setup`
|
771
|
-
:py:meth:`flcSphericalMask_setup`
|
772
|
-
"""
|
773
|
-
callback = callback_class
|
774
|
-
if callback_class is not None and isinstance(callback_class, type):
|
775
|
-
callback = callback_class(**callback_class_args)
|
776
|
-
|
777
|
-
template_buffer, template_shape, template_dtype = template
|
778
|
-
template = backend.sharedarr_to_arr(template_buffer, template_shape, template_dtype)
|
779
|
-
ft_target = backend.sharedarr_to_arr(*ft_target)
|
780
|
-
inv_denominator = backend.sharedarr_to_arr(*inv_denominator)
|
781
|
-
numerator2 = backend.sharedarr_to_arr(*numerator2)
|
782
|
-
template_filter = backend.sharedarr_to_arr(*template_filter)
|
783
|
-
|
784
|
-
norm_template, template_mask, mask_sum = False, 1, 1
|
785
|
-
if "template_mask" in kwargs:
|
786
|
-
template_mask = backend.sharedarr_to_arr(
|
787
|
-
kwargs["template_mask"][0], template_shape, template_dtype
|
788
|
-
)
|
789
|
-
norm_template, mask_sum = True, backend.sum(template_mask)
|
790
|
-
|
791
|
-
arr = backend.preallocate_array(fast_shape, real_dtype)
|
792
|
-
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
793
|
-
|
794
|
-
rfftn, irfftn = backend.build_fft(
|
795
|
-
fast_shape=fast_shape,
|
796
|
-
fast_ft_shape=fast_ft_shape,
|
797
|
-
real_dtype=real_dtype,
|
798
|
-
complex_dtype=complex_dtype,
|
799
|
-
fftargs=kwargs.get("fftargs", {}),
|
800
|
-
temp_real=arr,
|
801
|
-
temp_fft=ft_temp,
|
802
|
-
)
|
803
|
-
|
804
|
-
norm_numerator = (backend.sum(numerator2) != 0) & (backend.size(numerator2) != 1)
|
805
|
-
norm_denominator = (backend.sum(inv_denominator) != 1) & (
|
806
|
-
backend.size(inv_denominator) != 1
|
807
|
-
)
|
808
|
-
|
809
|
-
norm_template = conditional_execute(normalize_under_mask, norm_template)
|
810
|
-
callback_func = conditional_execute(callback, callback_class is not None)
|
811
|
-
norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
|
812
|
-
norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
|
813
|
-
template_filter_func = conditional_execute(
|
814
|
-
apply_filter, backend.size(template_filter) != 1
|
815
|
-
)
|
816
|
-
|
817
|
-
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
818
|
-
for index in range(rotations.shape[0]):
|
819
|
-
rotation = rotations[index]
|
820
|
-
backend.fill(arr, 0)
|
821
|
-
backend.rotate_array(
|
822
|
-
arr=template,
|
823
|
-
rotation_matrix=rotation,
|
824
|
-
out=arr,
|
825
|
-
use_geometric_center=False,
|
826
|
-
order=interpolation_order,
|
827
|
-
)
|
828
|
-
|
829
|
-
norm_template(arr[unpadded_slice], template_mask, mask_sum)
|
830
|
-
|
831
|
-
rfftn(arr, ft_temp)
|
832
|
-
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
833
|
-
backend.multiply(ft_target, ft_temp, out=ft_temp)
|
834
|
-
irfftn(ft_temp, arr)
|
835
|
-
|
836
|
-
norm_func_numerator(arr, numerator2, out=arr)
|
837
|
-
norm_func_denominator(arr, inv_denominator, out=arr)
|
838
|
-
|
839
|
-
callback_func(
|
840
|
-
arr,
|
841
|
-
rotation_matrix=rotation,
|
842
|
-
rotation_index=index,
|
843
|
-
**callback_class_args,
|
844
|
-
)
|
845
|
-
|
846
|
-
return callback
|
847
|
-
|
848
|
-
|
849
|
-
def flc_scoring(
|
850
|
-
template: Tuple[type, Tuple[int], type],
|
851
|
-
template_mask: Tuple[type, Tuple[int], type],
|
852
|
-
ft_target: Tuple[type, Tuple[int], type],
|
853
|
-
ft_target2: Tuple[type, Tuple[int], type],
|
854
|
-
template_filter: Tuple[type, Tuple[int], type],
|
855
|
-
targetshape: Tuple[int],
|
856
|
-
templateshape: Tuple[int],
|
857
|
-
fast_shape: Tuple[int],
|
858
|
-
fast_ft_shape: Tuple[int],
|
859
|
-
rotations: NDArray,
|
860
|
-
real_dtype: type,
|
861
|
-
complex_dtype: type,
|
862
|
-
callback_class: CallbackClass,
|
863
|
-
callback_class_args: Dict,
|
864
|
-
interpolation_order: int,
|
865
|
-
**kwargs,
|
866
|
-
) -> CallbackClass:
|
867
|
-
"""
|
868
|
-
Computes a normalized cross-correlation score of a target f a template g
|
869
|
-
and a mask m:
|
870
|
-
|
871
|
-
.. math::
|
872
|
-
|
873
|
-
\\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
|
874
|
-
{N_m * \\sqrt{
|
875
|
-
\\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
|
876
|
-
}
|
877
|
-
|
878
|
-
Where:
|
879
|
-
|
880
|
-
.. math::
|
881
|
-
|
882
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
883
|
-
|
884
|
-
and Nm is the number of voxels within the template mask m.
|
885
|
-
|
886
|
-
References
|
887
|
-
----------
|
888
|
-
.. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
|
889
|
-
Microsc. Microanal. 26, 2516 (2020)
|
890
|
-
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
891
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
892
|
-
"""
|
893
|
-
callback = callback_class
|
894
|
-
if callback_class is not None and isinstance(callback_class, type):
|
895
|
-
callback = callback_class(**callback_class_args)
|
896
|
-
|
897
|
-
template = backend.sharedarr_to_arr(*template)
|
898
|
-
template_mask = backend.sharedarr_to_arr(*template_mask)
|
899
|
-
ft_target = backend.sharedarr_to_arr(*ft_target)
|
900
|
-
ft_target2 = backend.sharedarr_to_arr(*ft_target2)
|
901
|
-
template_filter = backend.sharedarr_to_arr(*template_filter)
|
902
|
-
|
903
|
-
arr = backend.preallocate_array(fast_shape, real_dtype)
|
904
|
-
temp = backend.preallocate_array(fast_shape, real_dtype)
|
905
|
-
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
906
|
-
|
907
|
-
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
908
|
-
ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
909
|
-
|
910
|
-
rfftn, irfftn = backend.build_fft(
|
911
|
-
fast_shape=fast_shape,
|
912
|
-
fast_ft_shape=fast_ft_shape,
|
913
|
-
real_dtype=real_dtype,
|
914
|
-
complex_dtype=complex_dtype,
|
915
|
-
fftargs=kwargs.get("fftargs", {}),
|
916
|
-
temp_real=arr,
|
917
|
-
temp_fft=ft_temp,
|
918
|
-
)
|
919
|
-
eps = backend.eps(real_dtype)
|
920
|
-
template_filter_func = conditional_execute(
|
921
|
-
apply_filter, backend.size(template_filter) != 1
|
922
|
-
)
|
923
|
-
callback_func = conditional_execute(callback, callback_class is not None)
|
924
|
-
|
925
|
-
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
926
|
-
for index in range(rotations.shape[0]):
|
927
|
-
rotation = rotations[index]
|
928
|
-
backend.fill(arr, 0)
|
929
|
-
backend.fill(temp, 0)
|
930
|
-
backend.rotate_array(
|
931
|
-
arr=template,
|
932
|
-
arr_mask=template_mask,
|
933
|
-
rotation_matrix=rotation,
|
934
|
-
out=arr,
|
935
|
-
out_mask=temp,
|
936
|
-
use_geometric_center=False,
|
937
|
-
order=interpolation_order,
|
938
|
-
)
|
939
|
-
# Given the amount of FFTs, might aswell normalize properly
|
940
|
-
n_observations = backend.sum(temp)
|
941
|
-
|
942
|
-
normalize_under_mask(
|
943
|
-
template=arr[unpadded_slice],
|
944
|
-
mask=temp[unpadded_slice],
|
945
|
-
mask_intensity=n_observations,
|
946
|
-
)
|
947
|
-
|
948
|
-
rfftn(temp, ft_temp)
|
949
|
-
|
950
|
-
backend.multiply(ft_target, ft_temp, out=ft_denom)
|
951
|
-
irfftn(ft_denom, temp)
|
952
|
-
backend.divide(temp, n_observations, out=temp)
|
953
|
-
backend.square(temp, out=temp)
|
954
|
-
|
955
|
-
backend.multiply(ft_target2, ft_temp, out=ft_denom)
|
956
|
-
irfftn(ft_denom, temp2)
|
957
|
-
backend.divide(temp2, n_observations, out=temp2)
|
958
|
-
|
959
|
-
backend.subtract(temp2, temp, out=temp)
|
960
|
-
backend.maximum(temp, 0.0, out=temp)
|
961
|
-
backend.sqrt(temp, out=temp)
|
962
|
-
backend.multiply(temp, n_observations, out=temp)
|
963
|
-
|
964
|
-
rfftn(arr, ft_temp)
|
965
|
-
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
966
|
-
backend.multiply(ft_target, ft_temp, out=ft_temp)
|
967
|
-
irfftn(ft_temp, arr)
|
968
|
-
|
969
|
-
tol = tol = 1e3 * eps * backend.max(backend.abs(temp))
|
970
|
-
nonzero_indices = temp > tol
|
971
|
-
backend.fill(temp2, 0)
|
972
|
-
temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
|
973
|
-
|
974
|
-
callback_func(
|
975
|
-
temp2,
|
976
|
-
rotation_matrix=rotation,
|
977
|
-
rotation_index=index,
|
978
|
-
**callback_class_args,
|
979
|
-
)
|
980
|
-
|
981
|
-
return callback
|
982
|
-
|
983
|
-
|
984
|
-
def flc_scoring2(
|
985
|
-
template: Tuple[type, Tuple[int], type],
|
986
|
-
template_mask: Tuple[type, Tuple[int], type],
|
987
|
-
ft_target: Tuple[type, Tuple[int], type],
|
988
|
-
ft_target2: Tuple[type, Tuple[int], type],
|
989
|
-
template_filter: Tuple[type, Tuple[int], type],
|
990
|
-
targetshape: Tuple[int],
|
991
|
-
templateshape: Tuple[int],
|
992
|
-
fast_shape: Tuple[int],
|
993
|
-
fast_ft_shape: Tuple[int],
|
994
|
-
rotations: NDArray,
|
995
|
-
real_dtype: type,
|
996
|
-
complex_dtype: type,
|
997
|
-
callback_class: CallbackClass,
|
998
|
-
callback_class_args: Dict,
|
999
|
-
interpolation_order: int,
|
1000
|
-
**kwargs,
|
1001
|
-
) -> CallbackClass:
|
1002
|
-
"""
|
1003
|
-
Computes a normalized cross-correlation score of a target f a template g
|
1004
|
-
and a mask m:
|
1005
|
-
|
1006
|
-
.. math::
|
1007
|
-
|
1008
|
-
\\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
|
1009
|
-
{N_m * \\sqrt{
|
1010
|
-
\\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
|
1011
|
-
}
|
1012
|
-
|
1013
|
-
Where:
|
1014
|
-
|
1015
|
-
.. math::
|
1016
|
-
|
1017
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
1018
|
-
|
1019
|
-
and Nm is the number of voxels within the template mask m.
|
1020
|
-
|
1021
|
-
References
|
1022
|
-
----------
|
1023
|
-
.. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
|
1024
|
-
Microsc. Microanal. 26, 2516 (2020)
|
1025
|
-
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
1026
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
1027
|
-
"""
|
1028
|
-
callback = callback_class
|
1029
|
-
if callback_class is not None and isinstance(callback_class, type):
|
1030
|
-
callback = callback_class(**callback_class_args)
|
1031
|
-
|
1032
|
-
# Retrieve objects from shared memory
|
1033
|
-
template = backend.sharedarr_to_arr(*template)
|
1034
|
-
template_mask = backend.sharedarr_to_arr(*template_mask)
|
1035
|
-
ft_target = backend.sharedarr_to_arr(*ft_target)
|
1036
|
-
ft_target2 = backend.sharedarr_to_arr(*ft_target2)
|
1037
|
-
template_filter = backend.sharedarr_to_arr(*template_filter)
|
1038
|
-
|
1039
|
-
arr = backend.preallocate_array(fast_shape, real_dtype)
|
1040
|
-
temp = backend.preallocate_array(fast_shape, real_dtype)
|
1041
|
-
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
1042
|
-
|
1043
|
-
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
1044
|
-
ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
1045
|
-
|
1046
|
-
eps = backend.eps(real_dtype)
|
1047
|
-
template_filter_func = conditional_execute(
|
1048
|
-
apply_filter, backend.size(template_filter) != 1
|
1049
|
-
)
|
1050
|
-
callback_func = conditional_execute(callback, callback_class is not None)
|
1051
|
-
|
1052
|
-
squeeze_axis = tuple(i for i, x in enumerate(template.shape) if x == 1)
|
1053
|
-
squeeze = tuple(
|
1054
|
-
slice(0, stop) if i not in squeeze_axis else 0
|
1055
|
-
for i, stop in enumerate(template.shape)
|
1056
|
-
)
|
1057
|
-
squeeze_fast = tuple(
|
1058
|
-
slice(0, stop) if i not in squeeze_axis else 0
|
1059
|
-
for i, stop in enumerate(fast_shape)
|
1060
|
-
)
|
1061
|
-
squeeze_fast_ft = tuple(
|
1062
|
-
slice(0, stop) if i not in squeeze_axis else 0
|
1063
|
-
for i, stop in enumerate(fast_ft_shape)
|
1064
|
-
)
|
1065
|
-
|
1066
|
-
rfftn, irfftn = backend.build_fft(
|
1067
|
-
fast_shape=temp[squeeze_fast].shape,
|
1068
|
-
fast_ft_shape=fast_ft_shape,
|
1069
|
-
real_dtype=real_dtype,
|
1070
|
-
complex_dtype=complex_dtype,
|
1071
|
-
fftargs=kwargs.get("fftargs", {}),
|
1072
|
-
inverse_fast_shape=fast_shape,
|
1073
|
-
temp_real=arr[squeeze_fast],
|
1074
|
-
temp_fft=ft_temp,
|
1075
|
-
)
|
1076
|
-
for index in range(rotations.shape[0]):
|
1077
|
-
rotation = rotations[index]
|
1078
|
-
backend.fill(arr, 0)
|
1079
|
-
backend.fill(temp, 0)
|
1080
|
-
backend.rotate_array(
|
1081
|
-
arr=template[squeeze],
|
1082
|
-
arr_mask=template_mask[squeeze],
|
1083
|
-
rotation_matrix=rotation,
|
1084
|
-
out=arr[squeeze],
|
1085
|
-
out_mask=temp[squeeze],
|
1086
|
-
use_geometric_center=False,
|
1087
|
-
order=interpolation_order,
|
1088
|
-
)
|
1089
|
-
# Given the amount of FFTs, might aswell normalize properly
|
1090
|
-
n_observations = backend.sum(temp)
|
1091
|
-
|
1092
|
-
normalize_under_mask(
|
1093
|
-
template=arr[squeeze],
|
1094
|
-
mask=temp[squeeze],
|
1095
|
-
mask_intensity=n_observations,
|
1096
|
-
)
|
1097
|
-
rfftn(temp[squeeze_fast], ft_temp[squeeze_fast_ft])
|
1098
|
-
|
1099
|
-
backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1100
|
-
irfftn(ft_denom, temp)
|
1101
|
-
backend.divide(temp, n_observations, out=temp)
|
1102
|
-
backend.square(temp, out=temp)
|
1103
|
-
|
1104
|
-
backend.multiply(ft_target2, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1105
|
-
irfftn(ft_denom, temp2)
|
1106
|
-
backend.divide(temp2, n_observations, out=temp2)
|
1107
|
-
|
1108
|
-
backend.subtract(temp2, temp, out=temp)
|
1109
|
-
backend.maximum(temp, 0.0, out=temp)
|
1110
|
-
backend.sqrt(temp, out=temp)
|
1111
|
-
backend.multiply(temp, n_observations, out=temp)
|
1112
|
-
|
1113
|
-
rfftn(arr[squeeze_fast], ft_temp[squeeze_fast_ft])
|
1114
|
-
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
1115
|
-
backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1116
|
-
irfftn(ft_denom, arr)
|
1117
|
-
|
1118
|
-
tol = tol = 1e3 * eps * backend.max(backend.abs(temp))
|
1119
|
-
nonzero_indices = temp > tol
|
1120
|
-
backend.fill(temp2, 0)
|
1121
|
-
temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
|
1122
|
-
|
1123
|
-
callback_func(
|
1124
|
-
temp2,
|
1125
|
-
rotation_matrix=rotation,
|
1126
|
-
rotation_index=index,
|
1127
|
-
**callback_class_args,
|
1128
|
-
)
|
1129
|
-
return callback
|
1130
|
-
|
1131
|
-
|
1132
|
-
def mcc_scoring(
|
1133
|
-
template: Tuple[type, Tuple[int], type],
|
1134
|
-
template_mask: Tuple[type, Tuple[int], type],
|
1135
|
-
ft_target: Tuple[type, Tuple[int], type],
|
1136
|
-
ft_target2: Tuple[type, Tuple[int], type],
|
1137
|
-
ft_target_mask: Tuple[type, Tuple[int], type],
|
1138
|
-
template_filter: Tuple[type, Tuple[int], type],
|
1139
|
-
targetshape: Tuple[int],
|
1140
|
-
templateshape: Tuple[int],
|
1141
|
-
fast_shape: Tuple[int],
|
1142
|
-
fast_ft_shape: Tuple[int],
|
1143
|
-
rotations: NDArray,
|
1144
|
-
real_dtype: type,
|
1145
|
-
complex_dtype: type,
|
1146
|
-
callback_class: CallbackClass,
|
1147
|
-
callback_class_args: type,
|
1148
|
-
interpolation_order: int,
|
1149
|
-
overlap_ratio: float = 0.3,
|
1150
|
-
**kwargs,
|
1151
|
-
) -> CallbackClass:
|
1152
|
-
"""
|
1153
|
-
Computes a cross-correlation score with masks for both template and target.
|
1154
|
-
|
1155
|
-
.. math::
|
1156
|
-
|
1157
|
-
\\frac{
|
1158
|
-
CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
|
1159
|
-
}{
|
1160
|
-
\\sqrt{
|
1161
|
-
(CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
|
1162
|
-
(CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
|
1163
|
-
}
|
1164
|
-
}
|
1165
|
-
|
1166
|
-
Where:
|
1167
|
-
|
1168
|
-
.. math::
|
1169
|
-
|
1170
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
1171
|
-
|
1172
|
-
|
1173
|
-
References
|
1174
|
-
----------
|
1175
|
-
.. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
|
1176
|
-
.. [2] https://scikit-image.org/docs/stable/api/skimage.registration.html
|
1177
|
-
|
1178
|
-
See Also
|
1179
|
-
--------
|
1180
|
-
:py:class:`tme.matching_optimization.MaskedCrossCorrelation`
|
38
|
+
Raises
|
39
|
+
------
|
40
|
+
Exception
|
41
|
+
Re-raises the last exception.
|
1181
42
|
"""
|
1182
|
-
|
1183
|
-
|
1184
|
-
|
1185
|
-
|
1186
|
-
# Retrieve objects from shared memory
|
1187
|
-
template_buffer, template_shape, template_dtype = template
|
1188
|
-
template = backend.sharedarr_to_arr(*template)
|
1189
|
-
target_ft = backend.sharedarr_to_arr(*ft_target)
|
1190
|
-
target_ft2 = backend.sharedarr_to_arr(*ft_target2)
|
1191
|
-
template_mask = backend.sharedarr_to_arr(*template_mask)
|
1192
|
-
target_mask_ft = backend.sharedarr_to_arr(*ft_target_mask)
|
1193
|
-
template_filter = backend.sharedarr_to_arr(*template_filter)
|
1194
|
-
|
1195
|
-
axes = tuple(range(template.ndim))
|
1196
|
-
eps = backend.eps(real_dtype)
|
1197
|
-
|
1198
|
-
# Allocate score and process specific arrays
|
1199
|
-
template_rot = backend.preallocate_array(fast_shape, real_dtype)
|
1200
|
-
mask_overlap = backend.preallocate_array(fast_shape, real_dtype)
|
1201
|
-
numerator = backend.preallocate_array(fast_shape, real_dtype)
|
1202
|
-
|
1203
|
-
temp = backend.preallocate_array(fast_shape, real_dtype)
|
1204
|
-
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
1205
|
-
temp3 = backend.preallocate_array(fast_shape, real_dtype)
|
1206
|
-
temp_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
1207
|
-
|
1208
|
-
rfftn, irfftn = backend.build_fft(
|
1209
|
-
fast_shape=fast_shape,
|
1210
|
-
fast_ft_shape=fast_ft_shape,
|
1211
|
-
real_dtype=real_dtype,
|
1212
|
-
complex_dtype=complex_dtype,
|
1213
|
-
fftargs=kwargs.get("fftargs", {}),
|
1214
|
-
temp_real=numerator,
|
1215
|
-
temp_fft=temp_ft,
|
1216
|
-
)
|
1217
|
-
|
1218
|
-
template_filter_func = conditional_execute(
|
1219
|
-
apply_filter, backend.size(template_filter) != 1
|
1220
|
-
)
|
1221
|
-
callback_func = conditional_execute(callback, callback_class is not None)
|
1222
|
-
|
1223
|
-
# Calculate scores across all rotations
|
1224
|
-
for index in range(rotations.shape[0]):
|
1225
|
-
rotation = rotations[index]
|
1226
|
-
backend.fill(template_rot, 0)
|
1227
|
-
backend.fill(temp, 0)
|
1228
|
-
|
1229
|
-
backend.rotate_array(
|
1230
|
-
arr=template,
|
1231
|
-
arr_mask=template_mask,
|
1232
|
-
rotation_matrix=rotation,
|
1233
|
-
out=template_rot,
|
1234
|
-
out_mask=temp,
|
1235
|
-
use_geometric_center=False,
|
1236
|
-
order=interpolation_order,
|
1237
|
-
)
|
1238
|
-
|
1239
|
-
backend.multiply(template_rot, temp > 0, out=template_rot)
|
1240
|
-
|
1241
|
-
# template_rot_ft
|
1242
|
-
rfftn(template_rot, temp_ft)
|
1243
|
-
template_filter_func(ft_template=temp_ft, template_filter=template_filter)
|
1244
|
-
irfftn(target_mask_ft * temp_ft, temp2)
|
1245
|
-
irfftn(target_ft * temp_ft, numerator)
|
1246
|
-
|
1247
|
-
# temp template_mask_rot | temp_ft template_mask_rot_ft
|
1248
|
-
# Calculate overlap of masks at every point in the convolution.
|
1249
|
-
# Locations with high overlap should not be taken into account.
|
1250
|
-
rfftn(temp, temp_ft)
|
1251
|
-
irfftn(temp_ft * target_mask_ft, mask_overlap)
|
1252
|
-
mask_overlap[:] = np.round(mask_overlap)
|
1253
|
-
mask_overlap[:] = np.maximum(mask_overlap, eps)
|
1254
|
-
irfftn(temp_ft * target_ft, temp)
|
1255
|
-
|
1256
|
-
backend.subtract(
|
1257
|
-
numerator,
|
1258
|
-
backend.divide(backend.multiply(temp, temp2), mask_overlap),
|
1259
|
-
out=numerator,
|
1260
|
-
)
|
1261
|
-
|
1262
|
-
# temp_3 = fixed_denom
|
1263
|
-
backend.multiply(temp_ft, target_ft2, out=temp_ft)
|
1264
|
-
irfftn(temp_ft, temp3)
|
1265
|
-
backend.subtract(
|
1266
|
-
temp3, backend.divide(backend.square(temp), mask_overlap), out=temp3
|
1267
|
-
)
|
1268
|
-
backend.maximum(temp3, 0.0, out=temp3)
|
1269
|
-
|
1270
|
-
# temp = moving_denom
|
1271
|
-
rfftn(backend.square(template_rot), temp_ft)
|
1272
|
-
backend.multiply(target_mask_ft, temp_ft, out=temp_ft)
|
1273
|
-
irfftn(temp_ft, temp)
|
1274
|
-
|
1275
|
-
backend.subtract(
|
1276
|
-
temp, backend.divide(backend.square(temp2), mask_overlap), out=temp
|
1277
|
-
)
|
1278
|
-
backend.maximum(temp, 0.0, out=temp)
|
1279
|
-
|
1280
|
-
# temp_2 = denom
|
1281
|
-
backend.multiply(temp3, temp, out=temp)
|
1282
|
-
backend.sqrt(temp, temp2)
|
43
|
+
if last_type is None:
|
44
|
+
return None
|
45
|
+
traceback.print_tb(last_traceback)
|
46
|
+
raise Exception(last_value)
|
1283
47
|
|
1284
|
-
# Pixels where `denom` is very small will introduce large
|
1285
|
-
# numbers after division. To get around this problem,
|
1286
|
-
# we zero-out problematic pixels.
|
1287
|
-
tol = 1e3 * eps * backend.max(backend.abs(temp2), axis=axes, keepdims=True)
|
1288
|
-
nonzero_indices = temp2 > tol
|
1289
48
|
|
1290
|
-
|
1291
|
-
|
1292
|
-
|
1293
|
-
|
1294
|
-
# Apply overlap ratio threshold
|
1295
|
-
number_px_threshold = overlap_ratio * backend.max(
|
1296
|
-
mask_overlap, axis=axes, keepdims=True
|
1297
|
-
)
|
1298
|
-
temp[mask_overlap < number_px_threshold] = 0.0
|
49
|
+
def _wrap_backend(func):
|
50
|
+
@wraps(func)
|
51
|
+
def wrapper(*args, backend_name: str, backend_args: Dict, **kwargs):
|
52
|
+
from tme.backends import backend as be
|
1299
53
|
|
1300
|
-
|
1301
|
-
|
1302
|
-
rotation_matrix=rotation,
|
1303
|
-
rotation_index=index,
|
1304
|
-
**callback_class_args,
|
1305
|
-
)
|
54
|
+
be.change_backend(backend_name, **backend_args)
|
55
|
+
return func(*args, **kwargs)
|
1306
56
|
|
1307
|
-
return
|
57
|
+
return wrapper
|
1308
58
|
|
1309
59
|
-def
+def _setup_template_filter_apply_target_filter(
     matching_data: MatchingData,
     rfftn: Callable,
     irfftn: Callable,
     fast_shape: Tuple[int],
     fast_ft_shape: Tuple[int],
+    pad_template_filter: bool = True,
 ):
     filter_template = isinstance(matching_data.template_filter, Compose)
     filter_target = isinstance(matching_data.target_filter, Compose)

-    template_filter =
-        shape=(1,), fill_value=1, dtype=backend._default_dtype
-    )
+    template_filter = be.full(shape=(1,), fill_value=1, dtype=be._float_dtype)

     if not filter_template and not filter_target:
         return template_filter

-    target_temp =
-
-
-
-
+    target_temp = be.topleft_pad(matching_data.target, fast_shape)
+    target_temp_ft = be.zeros(fast_ft_shape, be._complex_dtype)
+
+    inv_mask = be.subtract(1, be.to_backend_array(matching_data._batch_mask))
+    filter_shape = be.multiply(be.to_backend_array(fast_ft_shape), inv_mask)
+    filter_shape = tuple(int(x) if x != 0 else 1 for x in filter_shape)
+
+    fast_shape = be.multiply(be.to_backend_array(fast_shape), inv_mask)
+    fast_shape = tuple(int(x) for x in fast_shape if x != 0)
+
+    target_temp_ft = rfftn(target_temp, target_temp_ft)
+    if filter_template:
+        # TODO: Pad to fast shapes and adapt _setup_template_filtering accordingly
+        template_fast_shape, template_filter_shape = fast_shape, filter_shape
+        if not pad_template_filter:
+            template_fast_shape = tuple(int(x) for x in matching_data._template.shape)
+            template_filter_shape = [
+                int(x) for x in matching_data._output_template_shape
+            ]
+            template_filter_shape[-1] = template_filter_shape[-1] // 2 + 1

-        if isinstance(matching_data.template_filter, Compose):
         template_filter = matching_data.template_filter(
-            shape=
+            shape=template_fast_shape,
             return_real_fourier=True,
             shape_is_real_fourier=False,
             data_rfft=target_temp_ft,
-
-
-        template_filter
+            batch_dimension=matching_data._target_dims,
+        )["data"]
+        template_filter = be.reshape(template_filter, template_filter_shape)

-    if
+    if filter_target:
         target_filter = matching_data.target_filter(
             shape=fast_shape,
             return_real_fourier=True,
             shape_is_real_fourier=False,
             data_rfft=target_temp_ft,
             weight_type=None,
-
-
-
+            batch_dimension=matching_data._target_dims,
+        )["data"]
+        target_filter = be.reshape(target_filter, filter_shape)
+        target_temp_ft = be.multiply(target_temp_ft, target_filter, out=target_temp_ft)

-    irfftn(target_temp_ft, target_temp)
-    matching_data._target =
-        target_temp, matching_data.target.shape
-    )
+    target_temp = irfftn(target_temp_ft, target_temp)
+    matching_data._target = be.topleft_pad(target_temp, matching_data.target.shape)

     return template_filter

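`_setup_template_filter_apply_target_filter` computes the Fourier-space template filter and, when a target filter is configured, multiplies it into the target's transform in place before writing the filtered target back to `matching_data._target`. When neither filter is a `Compose` instance it returns a single-element array of ones, which broadcasts as a no-op wherever the template filter is applied. A small sketch of that degenerate case with the default NumPy backend (shapes are arbitrary):

>>> from tme.backends import backend as be
>>> template_filter = be.full(shape=(1,), fill_value=1, dtype=be._float_dtype)
>>> ft = be.zeros((32, 32, 17), be._complex_dtype)  # mock rfft of a 32**3 volume
>>> be.multiply(ft, template_filter).shape
(32, 32, 17)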
@@ -1368,13 +130,14 @@ def device_memory_handler(func: Callable):
         last_type, last_value, last_traceback = sys.exc_info()
         try:
             with SharedMemoryManager() as smh:
-
+                gpu_index = kwargs.pop("gpu_index") if "gpu_index" in kwargs else 0
+                with be.set_device(gpu_index):
                     return_value = func(shared_memory_handler=smh, *args, **kwargs)
         except Exception as e:
             print(e)
             last_type, last_value, last_traceback = sys.exc_info()
         finally:
-
+            _handle_traceback(last_type, last_value, last_traceback)
         return return_value

     return inner_function
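`device_memory_handler` wraps a matching routine so that it runs with a `SharedMemoryManager` injected as `shared_memory_handler`, on the device selected through an optional `gpu_index` keyword (defaulting to 0), and with failures reported through `_handle_traceback`. A minimal sketch with a hypothetical function, only to illustrate which arguments are injected:

>>> @device_memory_handler
>>> def report(shared_memory_handler=None, **kwargs):  # hypothetical example function
>>>     return shared_memory_handler is not None
>>> report(gpu_index=0)
True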
@@ -1388,19 +151,20 @@ def scan(
     n_jobs: int = 4,
     callback_class: CallbackClass = None,
     callback_class_args: Dict = {},
-    fftargs: Dict = {},
     pad_fourier: bool = True,
+    pad_template_filter: bool = True,
     interpolation_order: int = 3,
     jobs_per_callback_class: int = 8,
-
-) -> Tuple:
+    shared_memory_handler=None,
+) -> Optional[Tuple]:
     """
-
-
+    Run template matching.
+
+    .. warning:: ``matching_data`` might be altered or destroyed during computation.

     Parameters
     ----------
-    matching_data : MatchingData
+    matching_data : :py:class:`tme.matching_data.MatchingData`
         Template matching data.
     matching_setup : Callable
         Function pointer to setup function.
@@ -1412,186 +176,131 @@ def scan(
         Analyzer class pointer to operate on computed scores.
     callback_class_args : dict, optional
         Arguments passed to the callback_class. Default is an empty dictionary.
-    fftargs : dict, optional
-        Arguments for the FFT operations. Default is an empty dictionary.
     pad_fourier: bool, optional
         Whether to pad target and template to the full convolution shape.
+    pad_template_filter: bool, optional
+        Whether to pad potential template filters to the full convolution shape.
     interpolation_order : int, optional
         Order of spline interpolation for rotations.
     jobs_per_callback_class : int, optional
         How many jobs should be processed by a single callback_class instance,
-        if
-
-
+        if one is provided.
+    shared_memory_handler : type, optional
+        Manager for shared memory objects, None by default.

     Returns
     -------
-    Tuple
+    Optional[Tuple]
         The merged results from callback_class if provided otherwise None.
+
+    Examples
+    --------
+    Schematically, using :py:meth:`scan` is similar to :py:meth:`scan_subsets`,
+    with the distinction that the objects contained in ``matching_data`` are not
+    split and the search is only parallelized over angles.
+    Assuming you have followed the example in :py:meth:`scan_subsets`, :py:meth:`scan`
+    can be invoked like so
+
+    >>> from tme.matching_exhaustive import scan
+    >>> results = scan(
+    >>>     matching_data=matching_data,
+    >>>     matching_score=matching_score,
+    >>>     matching_setup=matching_setup,
+    >>>     callback_class=callback_class,
+    >>>     callback_class_args=callback_class_args,
+    >>> )
+
     """
     matching_data.to_backend()
-    shape_diff = backend.subtract(
-        matching_data._output_target_shape, matching_data._output_template_shape
-    )
-    shape_diff = backend.multiply(shape_diff, ~matching_data._batch_mask)
-    if backend.sum(shape_diff < 0) and not pad_fourier:
-        warnings.warn(
-            "Target is larger than template and Fourier padding is turned off."
-            " This can lead to shifted results. You can swap template and target, or"
-            " zero-pad the target."
-        )
     fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
         pad_fourier=pad_fourier
     )

-
-    rfftn, irfftn = backend.build_fft(
+    rfftn, irfftn = be.build_fft(
         fast_shape=fast_shape,
         fast_ft_shape=fast_ft_shape,
-        real_dtype=
-        complex_dtype=
-        fftargs=fftargs,
+        real_dtype=be._float_dtype,
+        complex_dtype=be._complex_dtype,
     )
-
-
-
-
-
-
-
-
+    template_filter = _setup_template_filter_apply_target_filter(
+        matching_data=matching_data,
+        rfftn=rfftn,
+        irfftn=irfftn,
+        fast_shape=fast_shape,
+        fast_ft_shape=fast_ft_shape,
+        pad_template_filter=pad_template_filter,
+    )
+    template_filter = be.astype(be.to_backend_array(template_filter), be._float_dtype)

     setup = matching_setup(
         rfftn=rfftn,
         irfftn=irfftn,
         template=matching_data.template,
+        template_filter=template_filter,
         template_mask=matching_data.template_mask,
         target=matching_data.target,
         target_mask=matching_data.target_mask,
         fast_shape=fast_shape,
         fast_ft_shape=fast_ft_shape,
-
-        complex_dtype=matching_data._complex_dtype,
-        callback_class=callback_class,
-        callback_class_args=callback_class_args,
-        **kwargs,
+        shared_memory_handler=shared_memory_handler,
     )
     rfftn, irfftn = None, None
-
-
-    template_filter, preprocessor = None, Preprocessor()
-    for method, parameters in matching_data.template_filter.items():
-        parameters["shape"] = fast_shape
-        parameters["omit_negative_frequencies"] = True
-        out = preprocessor.apply_method(method=method, parameters=parameters)
-        if template_filter is None:
-            template_filter = out
-        np.multiply(template_filter, out, out=template_filter)
-
-    if template_filter is None:
-        template_filter = backend.full(
-            shape=(1,), fill_value=1, dtype=backend._default_dtype
-        )
-
-    template_filter = backend.to_backend_array(template_filter)
-    template_filter = backend.astype(template_filter, backend._default_dtype)
-    template_filter_buffer = backend.arr_to_sharedarr(
-        arr=template_filter,
-        shared_memory_handler=kwargs.get("shared_memory_handler", None),
-    )
-    setup["template_filter"] = (
-        template_filter_buffer,
-        template_filter.shape,
-        template_filter.dtype,
-    )
-
-    callback_class_args["translation_offset"] = backend.astype(
-        matching_data._translation_offset, int
-    )
-    callback_class_args["thread_safe"] = n_jobs > 1
-    callback_class_args["gpu_index"] = kwargs.get("gpu_index", -1)
-
-    n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
-    callback_class = setup.pop("callback_class", callback_class)
-    callback_class_args = setup.pop("callback_class_args", callback_class_args)
-    callback_classes = [callback_class for _ in range(n_callback_classes)]
-
-    convolution_mode = "same"
-    if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
-        convolution_mode = "valid"
-
-
-    callback_class_args["fourier_shift"] = fourier_shift
-    callback_class_args["convolution_mode"] = convolution_mode
-    callback_class_args["targetshape"] = setup["targetshape"]
-    callback_class_args["templateshape"] = setup["templateshape"]
-
-    if callback_class == MaxScoreOverRotations:
-        callback_classes = [
-            class_name(
-                score_space_shape=fast_shape,
-                score_space_dtype=matching_data._default_dtype,
-                shared_memory_handler=kwargs.get("shared_memory_handler", None),
-                rotation_space_dtype=backend._default_dtype_int,
-                **callback_class_args,
-            )
-            for class_name in callback_classes
-        ]
-
-    matching_data._target, matching_data._template = None, None
-    matching_data._target_mask, matching_data._template_mask = None, None
-
-    setup["fftargs"] = fftargs.copy()
-    convolution_mode = "same"
-    if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
-        convolution_mode = "valid"
-    setup["convolution_mode"] = convolution_mode
     setup["interpolation_order"] = interpolation_order
-
-
-
-
-
-
+    setup["template_filter"] = be.to_sharedarr(template_filter, shared_memory_handler)
+
+    offset = be.to_backend_array(matching_data._translation_offset)
+    convmode = "valid" if getattr(matching_data, "_is_padded", False) else "same"
+    default_callback_args = {
+        "offset": be.astype(offset, be._int_dtype),
+        "thread_safe": n_jobs > 1,
+        "fourier_shift": fourier_shift,
+        "convolution_mode": convmode,
+        "targetshape": matching_data.target.shape,
+        "templateshape": matching_data.template.shape,
+        "fast_shape": fast_shape,
+        "indices": getattr(matching_data, "indices", None),
+        "shared_memory_handler": shared_memory_handler,
+        "only_unique_rotations": True,
+    }
+    default_callback_args.update(callback_class_args)

-
-
+    matching_data._free_data()
+    be.free_cache()

+    # For highly parallel jobs, blocking in certain analyzers becomes a bottleneck
+    if getattr(callback_class, "shared", True):
+        jobs_per_callback_class = 1
+    n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
+    callback_classes = [
+        callback_class(
+            shape=fast_shape,
+            **default_callback_args,
+        )
+        if callback_class is not None
+        else None
+        for _ in range(n_callback_classes)
+    ]
     callbacks = Parallel(n_jobs=n_jobs)(
-        delayed(
-            backend_name=
-            backend_args=
+        delayed(_wrap_backend(matching_score))(
+            backend_name=be._backend_name,
+            backend_args=be._backend_args,
             rotations=rotation,
-
-            callback_class_args=callback_class_args,
+            callback=callback_classes[index % n_callback_classes],
             **setup,
         )
-        for index, rotation in enumerate(
+        for index, rotation in enumerate(matching_data._split_rotations_on_jobs(n_jobs))
     )

-    callbacks = callbacks[0:n_callback_classes]
     callbacks = [
-        tuple(callback._postprocess(
-
-
-            targetshape = setup["targetshape"],
-            templateshape = setup["templateshape"],
-            shared_memory_handler=kwargs.get("shared_memory_handler", None)
-        )) if hasattr(callback, "_postprocess") else tuple(callback)
-        for callback in callbacks if callback is not None
+        tuple(callback._postprocess(**default_callback_args))
+        for callback in callbacks
+        if callback is not None
     ]
-
+    be.free_cache()

-    merged_callback = None
     if callback_class is not None:
-
-
-            **callback_class_args,
-            score_indices=matching_data.indices,
-            inner_merge=True,
-        )
-
-    return merged_callback
+        return callback_class.merge(callbacks, **default_callback_args)
+    return None


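The analyzer protocol `scan` relies on can be read off the hunk above: the callback class is instantiated once per worker group with `shape=fast_shape` plus the `default_callback_args`, each instance's `_postprocess` is called with the same arguments after matching, and `callback_class.merge` combines the per-worker results. A hypothetical stand-in illustrating just that life cycle (the per-rotation call signature used inside the scoring functions is not shown in this hunk and is omitted here):

>>> class DummyAnalyzer:  # illustrative only, not part of tme.analyzer
>>>     shared = False  # scan() reads this attribute when sizing the analyzer pool
>>>     def __init__(self, shape, **kwargs):
>>>         self.shape = shape
>>>     def _postprocess(self, **kwargs):
>>>         return (self.shape,)
>>>     @classmethod
>>>     def merge(cls, results, **kwargs):
>>>         return results[0]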
 def scan_subsets(
@@ -1605,19 +314,21 @@ def scan_subsets(
     template_splits: Dict = {},
     pad_target_edges: bool = False,
     pad_fourier: bool = True,
+    pad_template_filter: bool = True,
     interpolation_order: int = 3,
     jobs_per_callback_class: int = 8,
-
-
+    backend_name: str = None,
+    backend_args: Dict = {},
+) -> Optional[Tuple]:
     """
-    Wrapper around :py:meth:`scan` that supports
-    of
+    Wrapper around :py:meth:`scan` that supports matching on splits
+    of ``matching_data``.

     Parameters
     ----------
-    matching_data : MatchingData
-
-
+    matching_data : :py:class:`tme.matching_data.MatchingData`
+        MatchingData instance containing relevant data.
+    matching_setup : type
         Function pointer to setup function.
     matching_score : type
         Function pointer to scoring function.
@@ -1626,88 +337,138 @@ def scan_subsets(
     callback_class_args : dict, optional
         Arguments passed to the callback_class. Default is an empty dictionary.
     job_schedule : tuple of int, optional
-
+        Job scheduling scheme, default is (1, 1). First value corresponds
+        to the number of splits that are processed in parallel, the second
+        to the number of angles evaluated in parallel on each split.
     target_splits : dict, optional
-        Splits for target. Default is an empty dictionary, i.e. no splits
+        Splits for target. Default is an empty dictionary, i.e. no splits.
+        See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
     template_splits : dict, optional
         Splits for template. Default is an empty dictionary, i.e. no splits.
+        See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
     pad_target_edges : bool, optional
         Whether to pad the target boundaries by half the template shape
         along each axis.
     pad_fourier: bool, optional
         Whether to pad target and template to the full convolution shape.
+    pad_template_filter: bool, optional
+        Whether to pad potential template filters to the full convolution shape.
     interpolation_order : int, optional
         Order of spline interpolation for rotations.
     jobs_per_callback_class : int, optional
         How many jobs should be processed by a single callback_class instance,
         if one is provided.
-    **kwargs : various
-        Additional arguments.
-
-    Notes
-    -----
-    Objects in matching_data might be destroyed during computation.

     Returns
     -------
-    Tuple
+    Optional[Tuple]
         The merged results from callback_class if provided otherwise None.
+
+    Examples
+    --------
+    All data relevant to template matching will be contained in ``matching_data``, which
+    is a :py:class:`tme.matching_data.MatchingData` instance and can be created like so
+
+    >>> import numpy as np
+    >>> from tme.matching_data import MatchingData
+    >>> from tme.matching_utils import get_rotation_matrices
+    >>> target = np.random.rand(50,40,60)
+    >>> template = target[15:25, 10:20, 30:40]
+    >>> matching_data = MatchingData(target, template)
+    >>> matching_data.rotations = get_rotation_matrices(
+    >>>     angular_sampling=60, dim=target.ndim
+    >>> )
+
+    The template matching procedure is determined by ``matching_setup`` and
+    ``matching_score``, which are unique to each score. In the following,
+    we will be using the `FLCSphericalMask` score, which is composed of
+    :py:meth:`flcSphericalMask_setup` and :py:meth:`corr_scoring`
+
+    >>> from tme.matching_exhaustive import MATCHING_EXHAUSTIVE_REGISTER
+    >>> funcs = MATCHING_EXHAUSTIVE_REGISTER.get("FLCSphericalMask")
+    >>> matching_setup, matching_score = funcs
+
+    Computed scores are flexibly analyzed by being passed through an analyzer. In the
+    following, we will use :py:class:`tme.analyzer.MaxScoreOverRotations` to
+    aggregate scores over rotations
+
+    >>> from tme.analyzer import MaxScoreOverRotations
+    >>> callback_class = MaxScoreOverRotations
+    >>> callback_class_args = {"score_threshold" : 0}
+
+    In case the entire template matching problem does not fit into memory, we can
+    determine the splitting procedure. In this case, we halve the first axis of the
+    target once. Splitting and ``job_schedule`` are typically computed using
+    :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
+
+    >>> target_splits = {0 : 1}
+
+    Finally, we can perform template matching. Note that the data
+    contained in ``matching_data`` will be destroyed when running the following
+
+    >>> from tme.matching_exhaustive import scan_subsets
+    >>> results = scan_subsets(
+    >>>     matching_data=matching_data,
+    >>>     matching_score=matching_score,
+    >>>     matching_setup=matching_setup,
+    >>>     callback_class=callback_class,
+    >>>     callback_class_args=callback_class_args,
+    >>>     target_splits=target_splits,
+    >>> )
+
+    The ``results`` tuple contains the output of the chosen analyzer.
+
+    See Also
+    --------
+    :py:meth:`tme.matching_utils.compute_parallelization_schedule`
     """
-
-
-    )
-
-
-
-
+    template_splits = split_shape(matching_data._template.shape, splits=template_splits)
+    target_splits = split_shape(matching_data._target.shape, splits=target_splits)
+    if (len(target_splits) > 1) and not pad_target_edges:
+        warnings.warn(
+            "Target splitting without padding target edges leads to unreliable "
+            "similarity estimates around the split border."
+        )
+    splits = tuple(product(target_splits, template_splits))

     outer_jobs, inner_jobs = job_schedule
-
-
-
-
-        matching_data=matching_data
-
-
-
-        ),
-        matching_score=matching_score,
-        matching_setup=matching_setup,
-        n_jobs=inner_jobs,
+    target_pad = matching_data.target_padding(pad_target=pad_target_edges)
+    if hasattr(be, "scan"):
+        corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
+        results = be.scan(
+            matching_data=matching_data,
+            splits=splits,
+            n_jobs=outer_jobs,
+            rotate_mask=matching_score != corr_scoring,
             callback_class=callback_class,
-        callback_class_args=callback_class_args,
-        interpolation_order=interpolation_order,
-        pad_fourier=pad_fourier,
-        gpu_index=index % outer_jobs,
-        **kwargs,
         )
-
-
+    else:
+        results = Parallel(n_jobs=outer_jobs)(
+            delayed(_wrap_backend(scan))(
+                backend_name=be._backend_name,
+                backend_args=be._backend_args,
+                matching_data=matching_data.subset_by_slice(
+                    target_slice=target_split,
+                    target_pad=target_pad,
+                    template_slice=template_split,
+                ),
+                matching_score=matching_score,
+                matching_setup=matching_setup,
+                n_jobs=inner_jobs,
+                callback_class=callback_class,
+                callback_class_args=callback_class_args,
+                interpolation_order=interpolation_order,
+                pad_fourier=pad_fourier,
+                gpu_index=index % outer_jobs,
+                pad_template_filter=pad_template_filter,
+            )
+            for index, (target_split, template_split) in enumerate(splits)
         )
-    )
-
-    matching_data._target, matching_data._template = None, None
-    matching_data._target_mask, matching_data._template_mask = None, None

-
+    matching_data._free_data()
     if callback_class is not None:
-
-
-        )
-
-    return candidates
-
-
-MATCHING_EXHAUSTIVE_REGISTER = {
-    "CC": (cc_setup, corr_scoring),
-    "LCC": (lcc_setup, corr_scoring),
-    "CORR": (corr_setup, corr_scoring),
-    "CAM": (cam_setup, corr_scoring),
-    "FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
-    "FLC": (flc_setup, flc_scoring),
-    "FLC2": (flc_setup, flc_scoring2),
-    "MCC": (mcc_setup, mcc_scoring),
-}
+        return callback_class.merge(results, **callback_class_args)
+    return None


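`register_matching_exhaustive` (below) adds a new entry to `MATCHING_EXHAUSTIVE_REGISTER` at runtime. A rough sketch that re-registers the existing CORR setup/scoring pair under a new name; the abstract method names on `MatchingMemoryUsage` (`base_usage`, `per_fork`) are assumptions for illustration and should be checked against `tme.memory`:

>>> from tme.memory import MatchingMemoryUsage
>>> from tme.matching_exhaustive import (
>>>     MATCHING_EXHAUSTIVE_REGISTER,
>>>     register_matching_exhaustive,
>>> )
>>> corr_setup, corr_scoring = MATCHING_EXHAUSTIVE_REGISTER["CORR"]
>>> class CorrCopyMemoryUsage(MatchingMemoryUsage):  # hypothetical memory estimator
>>>     def base_usage(self):  # assumed interface
>>>         return 0
>>>     def per_fork(self):  # assumed interface
>>>         return 0
>>> register_matching_exhaustive(
>>>     matching="CORR_copy",
>>>     matching_setup=corr_setup,
>>>     matching_scoring=corr_scoring,
>>>     memory_class=CorrCopyMemoryUsage,
>>> )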
 def register_matching_exhaustive(
@@ -1724,20 +485,17 @@ def register_matching_exhaustive(
     matching : str
         Name of the matching method.
     matching_setup : Callable
-
+        Corresponding setup function.
     matching_scoring : Callable
-
+        Corresponding scoring function.
     memory_class : MatchingMemoryUsage
-
-        :py:class:`tme.matching_memory.MatchingMemoryUsage`.
+        Child of :py:class:`tme.memory.MatchingMemoryUsage`.

     Raises
     ------
     ValueError
-        If a function with the name ``matching`` already exists in the registry
-
-        If ``memory_class`` is not a subclass of
-        :py:class:`tme.matching_memory.MatchingMemoryUsage`.
+        If a function with the name ``matching`` already exists in the registry, or
+        if ``memory_class`` is not a child of :py:class:`tme.memory.MatchingMemoryUsage`.
     """

     if matching in MATCHING_EXHAUSTIVE_REGISTER: