pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
- pytme-0.2.3.data/scripts/preprocess.py +132 -0
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
- pytme-0.2.3.dist-info/METADATA +92 -0
- pytme-0.2.3.dist-info/RECORD +75 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
- pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
- scripts/extract_candidates.py +20 -13
- scripts/match_template.py +219 -216
- scripts/match_template_filters.py +154 -95
- scripts/postprocess.py +86 -54
- scripts/preprocess.py +95 -56
- scripts/preprocessor_gui.py +181 -94
- scripts/refine_matches.py +265 -61
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +458 -813
- tme/backends/__init__.py +40 -11
- tme/backends/_jax_utils.py +187 -0
- tme/backends/cupy_backend.py +109 -226
- tme/backends/jax_backend.py +230 -152
- tme/backends/matching_backend.py +445 -384
- tme/backends/mlx_backend.py +32 -59
- tme/backends/npfftw_backend.py +240 -507
- tme/backends/pytorch_backend.py +30 -151
- tme/density.py +248 -371
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +328 -284
- tme/matching_exhaustive.py +195 -1499
- tme/matching_optimization.py +143 -106
- tme/matching_scores.py +887 -0
- tme/matching_utils.py +287 -388
- tme/memory.py +377 -0
- tme/orientations.py +78 -21
- tme/parser.py +3 -4
- tme/preprocessing/_utils.py +61 -32
- tme/preprocessing/composable_filter.py +7 -4
- tme/preprocessing/compose.py +7 -3
- tme/preprocessing/frequency_filters.py +49 -39
- tme/preprocessing/tilt_series.py +44 -72
- tme/preprocessor.py +560 -526
- tme/structure.py +491 -188
- tme/types.py +5 -3
- pytme-0.2.1.dist-info/METADATA +0 -73
- pytme-0.2.1.dist-info/RECORD +0 -73
- tme/helpers.py +0 -881
- tme/matching_constrained.py +0 -195
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
tme/matching_exhaustive.py
CHANGED
@@ -4,1300 +4,57 @@
|
|
4
4
|
|
5
5
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
6
|
"""
|
7
|
-
import os
|
8
7
|
import sys
|
9
8
|
import warnings
|
10
|
-
|
11
|
-
from typing import Callable, Tuple, Dict
|
9
|
+
import traceback
|
12
10
|
from functools import wraps
|
11
|
+
from itertools import product
|
12
|
+
from typing import Callable, Tuple, Dict, Optional
|
13
|
+
|
13
14
|
from joblib import Parallel, delayed
|
14
15
|
from multiprocessing.managers import SharedMemoryManager
|
15
16
|
|
16
|
-
import
|
17
|
-
from scipy.ndimage import laplace
|
18
|
-
|
19
|
-
from .analyzer import MaxScoreOverRotations
|
20
|
-
from .matching_utils import (
|
21
|
-
handle_traceback,
|
22
|
-
split_numpy_array_slices,
|
23
|
-
conditional_execute,
|
24
|
-
)
|
25
|
-
from .matching_memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
|
17
|
+
from .backends import backend as be
|
26
18
|
from .preprocessing import Compose
|
27
|
-
from .
|
28
|
-
from .
|
29
|
-
from .types import
|
30
|
-
|
31
|
-
os.environ["MKL_NUM_THREADS"] = "1"
|
32
|
-
os.environ["OMP_NUM_THREADS"] = "1"
|
33
|
-
os.environ["PYFFTW_NUM_THREADS"] = "1"
|
34
|
-
os.environ["OPENBLAS_NUM_THREADS"] = "1"
|
35
|
-
|
36
|
-
|
37
|
-
def _run_inner(backend_name, backend_args, **kwargs):
|
38
|
-
from tme.backends import backend
|
39
|
-
|
40
|
-
backend.change_backend(backend_name, **backend_args)
|
41
|
-
return scan(**kwargs)
|
42
|
-
|
43
|
-
|
44
|
-
def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
|
45
|
-
"""
|
46
|
-
Standardizes the values in in template by subtracting the mean and dividing by the
|
47
|
-
standard deviation based on the elements in mask. Subsequently, the template is
|
48
|
-
multiplied by the mask.
|
49
|
-
|
50
|
-
Parameters
|
51
|
-
----------
|
52
|
-
template : NDArray
|
53
|
-
The data array to be normalized. This array is modified in-place.
|
54
|
-
mask : NDArray
|
55
|
-
A boolean array of the same shape as `template`. True values indicate the positions in `template`
|
56
|
-
to consider for normalization.
|
57
|
-
mask_intensity : float
|
58
|
-
Mask intensity used to compute expectations.
|
59
|
-
|
60
|
-
References
|
61
|
-
----------
|
62
|
-
.. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
63
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
64
|
-
|
65
|
-
Returns
|
66
|
-
-------
|
67
|
-
None
|
68
|
-
This function modifies `template` in-place and does not return any value.
|
69
|
-
"""
|
70
|
-
masked_mean = backend.sum(backend.multiply(template, mask))
|
71
|
-
masked_mean = backend.divide(masked_mean, mask_intensity)
|
72
|
-
masked_std = backend.sum(backend.multiply(backend.square(template), mask))
|
73
|
-
masked_std = backend.subtract(
|
74
|
-
masked_std / mask_intensity, backend.square(masked_mean)
|
75
|
-
)
|
76
|
-
masked_std = backend.sqrt(backend.maximum(masked_std, 0))
|
77
|
-
|
78
|
-
backend.subtract(template, masked_mean, out=template)
|
79
|
-
backend.divide(template, masked_std, out=template)
|
80
|
-
backend.multiply(template, mask, out=template)
|
81
|
-
return None
|
82
|
-
|
83
|
-
|
84
|
-
def _normalize_under_mask_overflow_safe(
|
85
|
-
template: NDArray, mask: NDArray, mask_intensity
|
86
|
-
) -> None:
|
87
|
-
_template = backend.astype(template, backend._overflow_safe_dtype)
|
88
|
-
_mask = backend.astype(mask, backend._overflow_safe_dtype)
|
89
|
-
normalize_under_mask(template=_template, mask=_mask, mask_intensity=mask_intensity)
|
90
|
-
template[:] = backend.astype(_template, template.dtype)
|
91
|
-
return None
|
92
|
-
|
93
|
-
|
94
|
-
def apply_filter(ft_template, template_filter):
|
95
|
-
# This is an approximation to applying the mask, irfftn, normalize, rfftn
|
96
|
-
std_before = backend.std(ft_template)
|
97
|
-
backend.multiply(ft_template, template_filter, out=ft_template)
|
98
|
-
backend.multiply(
|
99
|
-
ft_template, std_before / backend.std(ft_template), out=ft_template
|
100
|
-
)
|
101
|
-
|
102
|
-
|
103
|
-
def cc_setup(
|
104
|
-
rfftn: Callable,
|
105
|
-
irfftn: Callable,
|
106
|
-
template: NDArray,
|
107
|
-
target: NDArray,
|
108
|
-
fast_shape: Tuple[int],
|
109
|
-
fast_ft_shape: Tuple[int],
|
110
|
-
shared_memory_handler: Callable,
|
111
|
-
callback_class: Callable,
|
112
|
-
callback_class_args: Dict,
|
113
|
-
**kwargs,
|
114
|
-
) -> Dict:
|
115
|
-
"""
|
116
|
-
Setup to compute the cross-correlation between a target f and template g
|
117
|
-
defined as:
|
118
|
-
|
119
|
-
.. math::
|
120
|
-
|
121
|
-
\\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
122
|
-
|
123
|
-
|
124
|
-
See Also
|
125
|
-
--------
|
126
|
-
:py:meth:`corr_scoring`
|
127
|
-
:py:class:`tme.matching_optimization.CrossCorrelation`
|
128
|
-
"""
|
129
|
-
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
130
|
-
target_pad = backend.topleft_pad(target, fast_shape)
|
131
|
-
target_pad_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
132
|
-
|
133
|
-
rfftn(target_pad, target_pad_ft)
|
134
|
-
|
135
|
-
target_ft_out = backend.arr_to_sharedarr(
|
136
|
-
arr=target_pad_ft, shared_memory_handler=shared_memory_handler
|
137
|
-
)
|
138
|
-
|
139
|
-
template_out = backend.arr_to_sharedarr(
|
140
|
-
arr=template, shared_memory_handler=shared_memory_handler
|
141
|
-
)
|
142
|
-
inv_denominator_buffer = backend.arr_to_sharedarr(
|
143
|
-
arr=backend.preallocate_array(1, real_dtype) + 1,
|
144
|
-
shared_memory_handler=shared_memory_handler,
|
145
|
-
)
|
146
|
-
numerator_buffer = backend.arr_to_sharedarr(
|
147
|
-
arr=backend.preallocate_array(1, real_dtype),
|
148
|
-
shared_memory_handler=shared_memory_handler,
|
149
|
-
)
|
150
|
-
|
151
|
-
target_ft_tuple = (target_ft_out, fast_ft_shape, complex_dtype)
|
152
|
-
template_tuple = (template_out, template.shape, template.dtype)
|
153
|
-
|
154
|
-
inv_denominator_tuple = (inv_denominator_buffer, (1,), real_dtype)
|
155
|
-
numerator_tuple = (numerator_buffer, (1,), real_dtype)
|
156
|
-
|
157
|
-
ret = {
|
158
|
-
"template": template_tuple,
|
159
|
-
"ft_target": target_ft_tuple,
|
160
|
-
"inv_denominator": inv_denominator_tuple,
|
161
|
-
"numerator": numerator_tuple,
|
162
|
-
"targetshape": target.shape,
|
163
|
-
"templateshape": template.shape,
|
164
|
-
"fast_shape": fast_shape,
|
165
|
-
"fast_ft_shape": fast_ft_shape,
|
166
|
-
"callback_class": callback_class,
|
167
|
-
"callback_class_args": callback_class_args,
|
168
|
-
"use_memmap": kwargs.get("use_memmap", False),
|
169
|
-
"temp_dir": kwargs.get("temp_dir", None),
|
170
|
-
}
|
171
|
-
|
172
|
-
return ret
|
173
|
-
|
174
|
-
|
175
|
-
def lcc_setup(**kwargs) -> Dict:
|
176
|
-
"""
|
177
|
-
Setup to compute the cross-correlation between a laplace transformed target f
|
178
|
-
and laplace transformed template g defined as:
|
179
|
-
|
180
|
-
.. math::
|
181
|
-
|
182
|
-
\\mathcal{F}^{-1}(\\mathcal{F}(\\nabla^{2}f) \\cdot \\mathcal{F}(\\nabla^{2} g)^*)
|
183
|
-
|
184
|
-
|
185
|
-
See Also
|
186
|
-
--------
|
187
|
-
:py:meth:`corr_scoring`
|
188
|
-
:py:class:`tme.matching_optimization.LaplaceCrossCorrelation`
|
189
|
-
"""
|
190
|
-
kwargs["target"] = laplace(kwargs["target"], mode="wrap")
|
191
|
-
kwargs["template"] = laplace(kwargs["template"], mode="wrap")
|
192
|
-
return cc_setup(**kwargs)
|
193
|
-
|
194
|
-
|
195
|
-
def corr_setup(
|
196
|
-
rfftn: Callable,
|
197
|
-
irfftn: Callable,
|
198
|
-
template: NDArray,
|
199
|
-
template_mask: NDArray,
|
200
|
-
target: NDArray,
|
201
|
-
fast_shape: Tuple[int],
|
202
|
-
fast_ft_shape: Tuple[int],
|
203
|
-
shared_memory_handler: Callable,
|
204
|
-
callback_class: Callable,
|
205
|
-
callback_class_args: Dict,
|
206
|
-
**kwargs,
|
207
|
-
) -> Dict:
|
208
|
-
"""
|
209
|
-
Setup to compute a normalized cross-correlation between a target f and a template g.
|
210
|
-
|
211
|
-
.. math::
|
212
|
-
|
213
|
-
\\frac{CC(f,g) - \\overline{g} \\cdot CC(f, m)}
|
214
|
-
{(CC(f^2, m) - \\frac{CC(f, m)^2}{N_g}) \\cdot \\sigma_{g}}
|
215
|
-
|
216
|
-
Where:
|
217
|
-
|
218
|
-
.. math::
|
219
|
-
|
220
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
221
|
-
|
222
|
-
and m is a mask with the same dimension as the template filled with ones.
|
223
|
-
|
224
|
-
References
|
225
|
-
----------
|
226
|
-
.. [1] J. P. Lewis, "Fast Normalized Cross-Correlation", Industrial Light
|
227
|
-
and Magic.
|
228
|
-
|
229
|
-
See Also
|
230
|
-
--------
|
231
|
-
:py:meth:`corr_scoring`
|
232
|
-
:py:class:`tme.matching_optimization.NormalizedCrossCorrelation`.
|
233
|
-
"""
|
234
|
-
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
235
|
-
target_pad = backend.topleft_pad(target, fast_shape)
|
236
|
-
|
237
|
-
# The exact composition of the denominator is debatable
|
238
|
-
# scikit-image match_template multiplies the running sum of the target
|
239
|
-
# with a scaling factor derived from the template. This is probably appropriate
|
240
|
-
# in pattern matching situations where the template exists in the target
|
241
|
-
window_template = backend.topleft_pad(template_mask, fast_shape)
|
242
|
-
ft_window_template = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
243
|
-
rfftn(window_template, ft_window_template)
|
244
|
-
window_template = None
|
245
|
-
|
246
|
-
# Target and squared target window sums
|
247
|
-
ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
248
|
-
ft_target2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
249
|
-
denominator = backend.preallocate_array(fast_shape, real_dtype)
|
250
|
-
target_window_sum = backend.preallocate_array(fast_shape, real_dtype)
|
251
|
-
rfftn(target_pad, ft_target)
|
252
|
-
|
253
|
-
rfftn(backend.square(target_pad), ft_target2)
|
254
|
-
backend.multiply(ft_target2, ft_window_template, out=ft_target2)
|
255
|
-
irfftn(ft_target2, denominator)
|
256
|
-
|
257
|
-
backend.multiply(ft_target, ft_window_template, out=ft_window_template)
|
258
|
-
irfftn(ft_window_template, target_window_sum)
|
259
|
-
|
260
|
-
target_pad, ft_target2, ft_window_template = None, None, None
|
261
|
-
|
262
|
-
# Normalizing constants
|
263
|
-
n_observations = backend.sum(template_mask)
|
264
|
-
template_mean = backend.sum(backend.multiply(template, template_mask))
|
265
|
-
template_mean = backend.divide(template_mean, n_observations)
|
266
|
-
template_ssd = backend.sum(
|
267
|
-
backend.square(
|
268
|
-
backend.multiply(backend.subtract(template, template_mean), template_mask)
|
269
|
-
)
|
270
|
-
)
|
271
|
-
template_volume = np.prod(tuple(int(x) for x in template.shape))
|
272
|
-
backend.multiply(template, template_mask, out=template)
|
273
|
-
|
274
|
-
# Final numerator is score - numerator
|
275
|
-
numerator = backend.multiply(target_window_sum, template_mean)
|
276
|
-
|
277
|
-
# Compute denominator
|
278
|
-
backend.multiply(target_window_sum, target_window_sum, out=target_window_sum)
|
279
|
-
backend.divide(target_window_sum, template_volume, out=target_window_sum)
|
280
|
-
|
281
|
-
backend.subtract(denominator, target_window_sum, out=denominator)
|
282
|
-
backend.multiply(denominator, template_ssd, out=denominator)
|
283
|
-
backend.maximum(denominator, 0, out=denominator)
|
284
|
-
backend.sqrt(denominator, out=denominator)
|
285
|
-
target_window_sum = None
|
286
|
-
|
287
|
-
# Invert denominator to compute final score as product
|
288
|
-
denominator_mask = denominator > backend.eps(denominator.dtype)
|
289
|
-
inv_denominator = backend.preallocate_array(fast_shape, real_dtype)
|
290
|
-
inv_denominator[denominator_mask] = 1 / denominator[denominator_mask]
|
291
|
-
|
292
|
-
# Convert arrays used in subsequent fitting to SharedMemory objects
|
293
|
-
template_buffer = backend.arr_to_sharedarr(
|
294
|
-
arr=template, shared_memory_handler=shared_memory_handler
|
295
|
-
)
|
296
|
-
target_ft_buffer = backend.arr_to_sharedarr(
|
297
|
-
arr=ft_target, shared_memory_handler=shared_memory_handler
|
298
|
-
)
|
299
|
-
inv_denominator_buffer = backend.arr_to_sharedarr(
|
300
|
-
arr=inv_denominator, shared_memory_handler=shared_memory_handler
|
301
|
-
)
|
302
|
-
numerator_buffer = backend.arr_to_sharedarr(
|
303
|
-
arr=numerator, shared_memory_handler=shared_memory_handler
|
304
|
-
)
|
305
|
-
|
306
|
-
template_tuple = (template_buffer, template.shape, template.dtype)
|
307
|
-
target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
|
308
|
-
|
309
|
-
inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
|
310
|
-
numerator_tuple = (numerator_buffer, fast_shape, real_dtype)
|
311
|
-
|
312
|
-
ft_target, inv_denominator, numerator = None, None, None
|
313
|
-
|
314
|
-
ret = {
|
315
|
-
"template": template_tuple,
|
316
|
-
"ft_target": target_ft_tuple,
|
317
|
-
"inv_denominator": inv_denominator_tuple,
|
318
|
-
"numerator": numerator_tuple,
|
319
|
-
"targetshape": target.shape,
|
320
|
-
"templateshape": template.shape,
|
321
|
-
"fast_shape": fast_shape,
|
322
|
-
"fast_ft_shape": fast_ft_shape,
|
323
|
-
"callback_class": callback_class,
|
324
|
-
"callback_class_args": callback_class_args,
|
325
|
-
"template_mean": kwargs.get("template_mean", template_mean),
|
326
|
-
}
|
327
|
-
|
328
|
-
return ret
|
329
|
-
|
330
|
-
|
331
|
-
def cam_setup(**kwargs):
|
332
|
-
"""
|
333
|
-
Setup to compute a normalized cross-correlation between a target f and a template g
|
334
|
-
over their means. In practice this can be expressed like the cross-correlation
|
335
|
-
CORR defined in :py:meth:`corr_scoring`, so that:
|
336
|
-
|
337
|
-
Notes
|
338
|
-
-----
|
339
|
-
|
340
|
-
.. math::
|
341
|
-
|
342
|
-
\\text{CORR}(f-\\overline{f}, g - \\overline{g})
|
343
|
-
|
344
|
-
Where
|
345
|
-
|
346
|
-
.. math::
|
347
|
-
|
348
|
-
\\overline{f}, \\overline{g}
|
349
|
-
|
350
|
-
are the mean of f and g respectively.
|
351
|
-
|
352
|
-
References
|
353
|
-
----------
|
354
|
-
.. [1] J. P. Lewis, "Fast Normalized Cross-Correlation", Industrial Light
|
355
|
-
and Magic.
|
356
|
-
|
357
|
-
See Also
|
358
|
-
--------
|
359
|
-
:py:meth:`corr_scoring`.
|
360
|
-
:py:class:`tme.matching_optimization.NormalizedCrossCorrelationMean`.
|
361
|
-
"""
|
362
|
-
kwargs["template"] = kwargs["template"] - kwargs["template"].mean()
|
363
|
-
return corr_setup(**kwargs)
|
364
|
-
|
365
|
-
|
366
|
-
def flc_setup(
|
367
|
-
rfftn: Callable,
|
368
|
-
irfftn: Callable,
|
369
|
-
template: NDArray,
|
370
|
-
template_mask: NDArray,
|
371
|
-
target: NDArray,
|
372
|
-
fast_shape: Tuple[int],
|
373
|
-
fast_ft_shape: Tuple[int],
|
374
|
-
shared_memory_handler: Callable,
|
375
|
-
callback_class: Callable,
|
376
|
-
callback_class_args: Dict,
|
377
|
-
**kwargs,
|
378
|
-
) -> Dict:
|
379
|
-
"""
|
380
|
-
Setup to compute a normalized cross-correlation score of a target f a template g
|
381
|
-
and a mask m:
|
382
|
-
|
383
|
-
.. math::
|
384
|
-
|
385
|
-
\\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
|
386
|
-
{N_m * \\sqrt{
|
387
|
-
\\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
|
388
|
-
}
|
389
|
-
|
390
|
-
Where:
|
391
|
-
|
392
|
-
.. math::
|
393
|
-
|
394
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
395
|
-
|
396
|
-
and Nm is the number of voxels within the template mask m.
|
397
|
-
|
398
|
-
References
|
399
|
-
----------
|
400
|
-
.. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
|
401
|
-
Microsc. Microanal. 26, 2516 (2020)
|
402
|
-
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
403
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
404
|
-
|
405
|
-
See Also
|
406
|
-
--------
|
407
|
-
:py:meth:`flc_scoring`
|
408
|
-
"""
|
409
|
-
target_pad = backend.topleft_pad(target, fast_shape)
|
410
|
-
|
411
|
-
# Target and squared target window sums
|
412
|
-
ft_target = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
|
413
|
-
ft_target2 = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
|
414
|
-
rfftn(target_pad, ft_target)
|
415
|
-
backend.square(target_pad, out=target_pad)
|
416
|
-
rfftn(target_pad, ft_target2)
|
417
|
-
|
418
|
-
# Convert arrays used in subsequent fitting to SharedMemory objects
|
419
|
-
ft_target = backend.arr_to_sharedarr(
|
420
|
-
arr=ft_target, shared_memory_handler=shared_memory_handler
|
421
|
-
)
|
422
|
-
ft_target2 = backend.arr_to_sharedarr(
|
423
|
-
arr=ft_target2, shared_memory_handler=shared_memory_handler
|
424
|
-
)
|
425
|
-
|
426
|
-
normalize_under_mask(
|
427
|
-
template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
|
428
|
-
)
|
429
|
-
|
430
|
-
template_buffer = backend.arr_to_sharedarr(
|
431
|
-
arr=template, shared_memory_handler=shared_memory_handler
|
432
|
-
)
|
433
|
-
template_mask_buffer = backend.arr_to_sharedarr(
|
434
|
-
arr=template_mask, shared_memory_handler=shared_memory_handler
|
435
|
-
)
|
436
|
-
|
437
|
-
template_tuple = (template_buffer, template.shape, template.dtype)
|
438
|
-
template_mask_tuple = (
|
439
|
-
template_mask_buffer,
|
440
|
-
template_mask.shape,
|
441
|
-
template_mask.dtype,
|
442
|
-
)
|
443
|
-
|
444
|
-
target_ft_tuple = (ft_target, fast_ft_shape, backend._complex_dtype)
|
445
|
-
target_ft2_tuple = (ft_target2, fast_ft_shape, backend._complex_dtype)
|
446
|
-
|
447
|
-
ret = {
|
448
|
-
"template": template_tuple,
|
449
|
-
"template_mask": template_mask_tuple,
|
450
|
-
"ft_target": target_ft_tuple,
|
451
|
-
"ft_target2": target_ft2_tuple,
|
452
|
-
"targetshape": target.shape,
|
453
|
-
"templateshape": template.shape,
|
454
|
-
"fast_shape": fast_shape,
|
455
|
-
"fast_ft_shape": fast_ft_shape,
|
456
|
-
"callback_class": callback_class,
|
457
|
-
"callback_class_args": callback_class_args,
|
458
|
-
}
|
459
|
-
|
460
|
-
return ret
|
461
|
-
|
462
|
-
|
463
|
-
def flcSphericalMask_setup(
|
464
|
-
rfftn: Callable,
|
465
|
-
irfftn: Callable,
|
466
|
-
template: NDArray,
|
467
|
-
template_mask: NDArray,
|
468
|
-
target: NDArray,
|
469
|
-
fast_shape: Tuple[int],
|
470
|
-
fast_ft_shape: Tuple[int],
|
471
|
-
shared_memory_handler: Callable,
|
472
|
-
callback_class: Callable,
|
473
|
-
callback_class_args: Dict,
|
474
|
-
**kwargs,
|
475
|
-
) -> Dict:
|
476
|
-
"""
|
477
|
-
Like :py:meth:`flc_setup` but for rotation invariant masks. In such cases
|
478
|
-
the score can be computed quicker by not computing the fourier transforms
|
479
|
-
of the mask for each rotation.
|
480
|
-
|
481
|
-
.. math::
|
482
|
-
|
483
|
-
\\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
|
484
|
-
{N_m * \\sqrt{
|
485
|
-
\\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
|
486
|
-
}
|
487
|
-
|
488
|
-
Where:
|
489
|
-
|
490
|
-
.. math::
|
491
|
-
|
492
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
493
|
-
|
494
|
-
and Nm is the number of voxels within the template mask m.
|
495
|
-
|
496
|
-
References
|
497
|
-
----------
|
498
|
-
.. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
|
499
|
-
Microsc. Microanal. 26, 2516 (2020)
|
500
|
-
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
501
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
502
|
-
|
503
|
-
See Also
|
504
|
-
--------
|
505
|
-
:py:meth:`corr_scoring`
|
506
|
-
"""
|
507
|
-
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
508
|
-
target_pad = backend.topleft_pad(target, fast_shape)
|
509
|
-
|
510
|
-
# Target and squared target window sums
|
511
|
-
ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
512
|
-
ft_template_mask = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
513
|
-
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
514
|
-
|
515
|
-
temp = backend.preallocate_array(fast_shape, real_dtype)
|
516
|
-
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
517
|
-
numerator = backend.preallocate_array(1, real_dtype)
|
518
|
-
|
519
|
-
n_observations, norm_func = backend.sum(template_mask), normalize_under_mask
|
520
|
-
if backend.datatype_bytes(template_mask.dtype) == 2:
|
521
|
-
norm_func = _normalize_under_mask_overflow_safe
|
522
|
-
n_observations = backend.sum(
|
523
|
-
backend.astype(template_mask, backend._overflow_safe_dtype)
|
524
|
-
)
|
525
|
-
|
526
|
-
template_mask_pad = backend.topleft_pad(template_mask, fast_shape)
|
527
|
-
rfftn(template_mask_pad, ft_template_mask)
|
528
|
-
|
529
|
-
# Denominator E(X^2) - E(X)^2
|
530
|
-
rfftn(backend.square(target_pad), ft_target)
|
531
|
-
backend.multiply(ft_target, ft_template_mask, out=ft_temp)
|
532
|
-
irfftn(ft_temp, temp2)
|
533
|
-
backend.divide(temp2, n_observations, out=temp2)
|
534
|
-
|
535
|
-
rfftn(target_pad, ft_target)
|
536
|
-
backend.multiply(ft_target, ft_template_mask, out=ft_temp)
|
537
|
-
irfftn(ft_temp, temp)
|
538
|
-
backend.divide(temp, n_observations, out=temp)
|
539
|
-
backend.square(temp, out=temp)
|
540
|
-
|
541
|
-
backend.subtract(temp2, temp, out=temp)
|
542
|
-
|
543
|
-
backend.maximum(temp, 0.0, out=temp)
|
544
|
-
backend.sqrt(temp, out=temp)
|
545
|
-
backend.multiply(temp, n_observations, out=temp)
|
546
|
-
|
547
|
-
backend.fill(temp2, 0)
|
548
|
-
nonzero_indices = temp > backend.eps(real_dtype)
|
549
|
-
temp2[nonzero_indices] = 1 / temp[nonzero_indices]
|
550
|
-
|
551
|
-
norm_func(template=template, mask=template_mask, mask_intensity=n_observations)
|
552
|
-
|
553
|
-
template_buffer = backend.arr_to_sharedarr(
|
554
|
-
arr=template, shared_memory_handler=shared_memory_handler
|
555
|
-
)
|
556
|
-
template_mask_buffer = backend.arr_to_sharedarr(
|
557
|
-
arr=template_mask, shared_memory_handler=shared_memory_handler
|
558
|
-
)
|
559
|
-
target_ft_buffer = backend.arr_to_sharedarr(
|
560
|
-
arr=ft_target, shared_memory_handler=shared_memory_handler
|
561
|
-
)
|
562
|
-
inv_denominator_buffer = backend.arr_to_sharedarr(
|
563
|
-
arr=temp2, shared_memory_handler=shared_memory_handler
|
564
|
-
)
|
565
|
-
numerator_buffer = backend.arr_to_sharedarr(
|
566
|
-
arr=numerator, shared_memory_handler=shared_memory_handler
|
567
|
-
)
|
568
|
-
|
569
|
-
template_tuple = (template_buffer, template.shape, template.dtype)
|
570
|
-
template_mask_tuple = (template_mask_buffer, template.shape, template_mask.dtype)
|
571
|
-
target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
|
572
|
-
|
573
|
-
inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
|
574
|
-
numerator_tuple = (numerator_buffer, (1,), real_dtype)
|
575
|
-
|
576
|
-
ret = {
|
577
|
-
"template": template_tuple,
|
578
|
-
"template_mask": template_mask_tuple,
|
579
|
-
"ft_target": target_ft_tuple,
|
580
|
-
"inv_denominator": inv_denominator_tuple,
|
581
|
-
"numerator": numerator_tuple,
|
582
|
-
"targetshape": target.shape,
|
583
|
-
"templateshape": template.shape,
|
584
|
-
"fast_shape": fast_shape,
|
585
|
-
"fast_ft_shape": fast_ft_shape,
|
586
|
-
"callback_class": callback_class,
|
587
|
-
"callback_class_args": callback_class_args,
|
588
|
-
}
|
589
|
-
|
590
|
-
return ret
|
591
|
-
|
592
|
-
|
593
|
-
def mcc_setup(
|
594
|
-
rfftn: Callable,
|
595
|
-
irfftn: Callable,
|
596
|
-
template: NDArray,
|
597
|
-
template_mask: NDArray,
|
598
|
-
target: NDArray,
|
599
|
-
target_mask: NDArray,
|
600
|
-
fast_shape: Tuple[int],
|
601
|
-
fast_ft_shape: Tuple[int],
|
602
|
-
shared_memory_handler: Callable,
|
603
|
-
callback_class: Callable,
|
604
|
-
callback_class_args: Dict,
|
605
|
-
**kwargs,
|
606
|
-
) -> Dict:
|
607
|
-
"""
|
608
|
-
Setup to compute a normalized cross-correlation score with masks
|
609
|
-
for both template and target.
|
610
|
-
|
611
|
-
.. math::
|
612
|
-
|
613
|
-
\\frac{
|
614
|
-
CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
|
615
|
-
}{
|
616
|
-
\\sqrt{
|
617
|
-
(CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
|
618
|
-
(CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
|
619
|
-
}
|
620
|
-
}
|
621
|
-
|
622
|
-
Where:
|
623
|
-
|
624
|
-
.. math::
|
625
|
-
|
626
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
19
|
+
from .matching_utils import split_shape
|
20
|
+
from .matching_scores import MATCHING_EXHAUSTIVE_REGISTER
|
21
|
+
from .types import CallbackClass, MatchingData
|
22
|
+
from .memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
|
627
23
|
|
628
24
|
|
629
|
-
|
630
|
-
----------
|
631
|
-
.. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
|
632
|
-
|
633
|
-
See Also
|
634
|
-
--------
|
635
|
-
:py:meth:`mcc_scoring`
|
636
|
-
:py:class:`tme.matching_optimization.MaskedCrossCorrelation`
|
25
|
+
def _handle_traceback(last_type, last_value, last_traceback):
|
637
26
|
"""
|
638
|
-
|
639
|
-
target = backend.multiply(target, target_mask > 0, out=target)
|
640
|
-
|
641
|
-
target_pad = backend.topleft_pad(target, fast_shape)
|
642
|
-
target_mask_pad = backend.topleft_pad(target_mask, fast_shape)
|
643
|
-
|
644
|
-
target_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
645
|
-
rfftn(target_pad, target_ft)
|
646
|
-
target_ft_buffer = backend.arr_to_sharedarr(
|
647
|
-
arr=target_ft, shared_memory_handler=shared_memory_handler
|
648
|
-
)
|
649
|
-
|
650
|
-
target_ft2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
651
|
-
rfftn(backend.square(target_pad), target_ft2)
|
652
|
-
target_ft2_buffer = backend.arr_to_sharedarr(
|
653
|
-
arr=target_ft2, shared_memory_handler=shared_memory_handler
|
654
|
-
)
|
655
|
-
|
656
|
-
target_mask_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
657
|
-
rfftn(target_mask_pad, target_mask_ft)
|
658
|
-
target_mask_ft_buffer = backend.arr_to_sharedarr(
|
659
|
-
arr=target_mask_ft, shared_memory_handler=shared_memory_handler
|
660
|
-
)
|
661
|
-
|
662
|
-
template_buffer = backend.arr_to_sharedarr(
|
663
|
-
arr=template, shared_memory_handler=shared_memory_handler
|
664
|
-
)
|
665
|
-
template_mask_buffer = backend.arr_to_sharedarr(
|
666
|
-
arr=template_mask, shared_memory_handler=shared_memory_handler
|
667
|
-
)
|
668
|
-
|
669
|
-
template_tuple = (template_buffer, template.shape, template.dtype)
|
670
|
-
template_mask_tuple = (template_mask_buffer, template.shape, template_mask.dtype)
|
671
|
-
|
672
|
-
target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
|
673
|
-
target_ft2_tuple = (target_ft2_buffer, fast_ft_shape, complex_dtype)
|
674
|
-
target_mask_ft_tuple = (target_mask_ft_buffer, fast_ft_shape, complex_dtype)
|
675
|
-
|
676
|
-
ret = {
|
677
|
-
"template": template_tuple,
|
678
|
-
"template_mask": template_mask_tuple,
|
679
|
-
"ft_target": target_ft_tuple,
|
680
|
-
"ft_target2": target_ft2_tuple,
|
681
|
-
"ft_target_mask": target_mask_ft_tuple,
|
682
|
-
"targetshape": target.shape,
|
683
|
-
"templateshape": template.shape,
|
684
|
-
"fast_shape": fast_shape,
|
685
|
-
"fast_ft_shape": fast_ft_shape,
|
686
|
-
"callback_class": callback_class,
|
687
|
-
"callback_class_args": callback_class_args,
|
688
|
-
}
|
689
|
-
|
690
|
-
return ret
|
691
|
-
|
692
|
-
|
693
|
-
def corr_scoring(
|
694
|
-
template: Tuple[type, Tuple[int], type],
|
695
|
-
ft_target: Tuple[type, Tuple[int], type],
|
696
|
-
inv_denominator: Tuple[type, Tuple[int], type],
|
697
|
-
numerator: Tuple[type, Tuple[int], type],
|
698
|
-
template_filter: Tuple[type, Tuple[int], type],
|
699
|
-
targetshape: Tuple[int],
|
700
|
-
templateshape: Tuple[int],
|
701
|
-
fast_shape: Tuple[int],
|
702
|
-
fast_ft_shape: Tuple[int],
|
703
|
-
rotations: NDArray,
|
704
|
-
callback_class: CallbackClass,
|
705
|
-
callback_class_args: Dict,
|
706
|
-
interpolation_order: int,
|
707
|
-
convolution_mode: str = "full",
|
708
|
-
**kwargs,
|
709
|
-
) -> CallbackClass:
|
710
|
-
"""
|
711
|
-
Calculates a normalized cross-correlation between a target f and a template g.
|
712
|
-
|
713
|
-
.. math::
|
714
|
-
|
715
|
-
(CC(f,g) - numerator) \\cdot inv\\_denominator
|
27
|
+
Handle sys.exc_info().
|
716
28
|
|
717
29
|
Parameters
|
718
30
|
----------
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
Tuple containing a pointer to the inverse denominator data, its shape, and its
|
726
|
-
datatype.
|
727
|
-
numerator : Tuple[type, Tuple[int], type]
|
728
|
-
Tuple containing a pointer to the numerator data, its shape, and its datatype.
|
729
|
-
fast_shape : Tuple[int]
|
730
|
-
The shape for fast Fourier transform.
|
731
|
-
fast_ft_shape : Tuple[int]
|
732
|
-
The shape for fast Fourier transform of the target.
|
733
|
-
rotations : NDArray
|
734
|
-
Array containing the rotation matrices to be applied on the template.
|
735
|
-
real_dtype : type
|
736
|
-
Data type for the real part of the array.
|
737
|
-
complex_dtype : type
|
738
|
-
Data type for the complex part of the array.
|
739
|
-
callback_class : CallbackClass
|
740
|
-
A callable class or function for processing the results after each
|
741
|
-
rotation.
|
742
|
-
callback_class_args : Dict
|
743
|
-
Dictionary of arguments to be passed to the callback class if it's
|
744
|
-
instantiable.
|
745
|
-
interpolation_order : int
|
746
|
-
The order of interpolation to be used while rotating the template.
|
747
|
-
**kwargs :
|
748
|
-
Additional arguments to be passed to the function.
|
749
|
-
|
750
|
-
Returns
|
751
|
-
-------
|
752
|
-
CallbackClass
|
753
|
-
If callback_class was provided an instance of callback_class otherwise None.
|
754
|
-
|
755
|
-
See Also
|
756
|
-
--------
|
757
|
-
:py:meth:`cc_setup`
|
758
|
-
:py:meth:`corr_setup`
|
759
|
-
:py:meth:`cam_setup`
|
760
|
-
:py:meth:`flcSphericalMask_setup`
|
761
|
-
"""
|
762
|
-
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
763
|
-
|
764
|
-
callback = callback_class
|
765
|
-
if callback_class is not None and isinstance(callback_class, type):
|
766
|
-
callback = callback_class(**callback_class_args)
|
767
|
-
|
768
|
-
template_buffer, template_shape, template_dtype = template
|
769
|
-
template = backend.sharedarr_to_arr(template_buffer, template_shape, template_dtype)
|
770
|
-
ft_target = backend.sharedarr_to_arr(*ft_target)
|
771
|
-
inv_denominator = backend.sharedarr_to_arr(*inv_denominator)
|
772
|
-
numerator = backend.sharedarr_to_arr(*numerator)
|
773
|
-
template_filter = backend.sharedarr_to_arr(*template_filter)
|
774
|
-
|
775
|
-
norm_func = normalize_under_mask
|
776
|
-
norm_template, template_mask, mask_sum = False, 1, 1
|
777
|
-
if "template_mask" in kwargs:
|
778
|
-
norm_template = True
|
779
|
-
template_mask = backend.sharedarr_to_arr(*kwargs["template_mask"])
|
780
|
-
mask_sum = backend.sum(template_mask)
|
781
|
-
if backend.datatype_bytes(template_mask.dtype) == 2:
|
782
|
-
norm_func = _normalize_under_mask_overflow_safe
|
783
|
-
mask_sum = backend.sum(
|
784
|
-
backend.astype(template_mask, backend._overflow_safe_dtype)
|
785
|
-
)
|
786
|
-
|
787
|
-
norm_template = conditional_execute(norm_func, norm_template)
|
788
|
-
norm_numerator = (backend.sum(numerator) != 0) & (backend.size(numerator) != 1)
|
789
|
-
norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
|
790
|
-
|
791
|
-
norm_denominator = (backend.sum(inv_denominator) != 1) & (
|
792
|
-
backend.size(inv_denominator) != 1
|
793
|
-
)
|
794
|
-
norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
|
795
|
-
callback_func = conditional_execute(callback, callback_class is not None)
|
796
|
-
template_filter_func = conditional_execute(
|
797
|
-
apply_filter, backend.size(template_filter) != 1
|
798
|
-
)
|
799
|
-
|
800
|
-
arr = backend.preallocate_array(fast_shape, real_dtype)
|
801
|
-
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
802
|
-
|
803
|
-
rfftn, irfftn = backend.build_fft(
|
804
|
-
fast_shape=fast_shape,
|
805
|
-
fast_ft_shape=fast_ft_shape,
|
806
|
-
real_dtype=real_dtype,
|
807
|
-
complex_dtype=complex_dtype,
|
808
|
-
fftargs=kwargs.get("fftargs", {}),
|
809
|
-
temp_real=arr,
|
810
|
-
temp_fft=ft_temp,
|
811
|
-
)
|
812
|
-
|
813
|
-
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
814
|
-
for index in range(rotations.shape[0]):
|
815
|
-
rotation = rotations[index]
|
816
|
-
backend.fill(arr, 0)
|
817
|
-
backend.rotate_array(
|
818
|
-
arr=template,
|
819
|
-
rotation_matrix=rotation,
|
820
|
-
out=arr,
|
821
|
-
use_geometric_center=True,
|
822
|
-
order=interpolation_order,
|
823
|
-
)
|
824
|
-
norm_template(arr[unpadded_slice], template_mask, mask_sum)
|
825
|
-
|
826
|
-
rfftn(arr, ft_temp)
|
827
|
-
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
828
|
-
backend.multiply(ft_target, ft_temp, out=ft_temp)
|
829
|
-
irfftn(ft_temp, arr)
|
830
|
-
|
831
|
-
norm_func_numerator(arr, numerator, out=arr)
|
832
|
-
norm_func_denominator(arr, inv_denominator, out=arr)
|
833
|
-
|
834
|
-
callback_func(
|
835
|
-
arr,
|
836
|
-
rotation_matrix=rotation,
|
837
|
-
rotation_index=index,
|
838
|
-
**callback_class_args,
|
839
|
-
)
|
840
|
-
|
841
|
-
return callback
|
842
|
-
|
843
|
-
|
844
|
-
def flc_scoring(
|
845
|
-
template: Tuple[type, Tuple[int], type],
|
846
|
-
template_mask: Tuple[type, Tuple[int], type],
|
847
|
-
ft_target: Tuple[type, Tuple[int], type],
|
848
|
-
ft_target2: Tuple[type, Tuple[int], type],
|
849
|
-
template_filter: Tuple[type, Tuple[int], type],
|
850
|
-
targetshape: Tuple[int],
|
851
|
-
templateshape: Tuple[int],
|
852
|
-
fast_shape: Tuple[int],
|
853
|
-
fast_ft_shape: Tuple[int],
|
854
|
-
rotations: NDArray,
|
855
|
-
callback_class: CallbackClass,
|
856
|
-
callback_class_args: Dict,
|
857
|
-
interpolation_order: int,
|
858
|
-
**kwargs,
|
859
|
-
) -> CallbackClass:
|
860
|
-
"""
|
861
|
-
Computes a normalized cross-correlation score of a target f a template g
|
862
|
-
and a mask m:
|
863
|
-
|
864
|
-
.. math::
|
865
|
-
|
866
|
-
\\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
|
867
|
-
{N_m * \\sqrt{
|
868
|
-
\\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
|
869
|
-
}
|
870
|
-
|
871
|
-
Where:
|
872
|
-
|
873
|
-
.. math::
|
874
|
-
|
875
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
876
|
-
|
877
|
-
and Nm is the number of voxels within the template mask m.
|
878
|
-
|
879
|
-
References
|
880
|
-
----------
|
881
|
-
.. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
|
882
|
-
Microsc. Microanal. 26, 2516 (2020)
|
883
|
-
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
884
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
885
|
-
"""
|
886
|
-
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
887
|
-
|
888
|
-
callback = callback_class
|
889
|
-
if callback_class is not None and isinstance(callback_class, type):
|
890
|
-
callback = callback_class(**callback_class_args)
|
891
|
-
|
892
|
-
template = backend.sharedarr_to_arr(*template)
|
893
|
-
template_mask = backend.sharedarr_to_arr(*template_mask)
|
894
|
-
ft_target = backend.sharedarr_to_arr(*ft_target)
|
895
|
-
ft_target2 = backend.sharedarr_to_arr(*ft_target2)
|
896
|
-
template_filter = backend.sharedarr_to_arr(*template_filter)
|
897
|
-
|
898
|
-
arr = backend.preallocate_array(fast_shape, real_dtype)
|
899
|
-
temp = backend.preallocate_array(fast_shape, real_dtype)
|
900
|
-
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
901
|
-
|
902
|
-
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
903
|
-
ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
904
|
-
|
905
|
-
rfftn, irfftn = backend.build_fft(
|
906
|
-
fast_shape=fast_shape,
|
907
|
-
fast_ft_shape=fast_ft_shape,
|
908
|
-
real_dtype=real_dtype,
|
909
|
-
complex_dtype=complex_dtype,
|
910
|
-
fftargs=kwargs.get("fftargs", {}),
|
911
|
-
temp_real=arr,
|
912
|
-
temp_fft=ft_temp,
|
913
|
-
)
|
914
|
-
eps = backend.eps(real_dtype)
|
915
|
-
template_filter_func = conditional_execute(
|
916
|
-
apply_filter, backend.size(template_filter) != 1
|
917
|
-
)
|
918
|
-
callback_func = conditional_execute(callback, callback_class is not None)
|
919
|
-
|
920
|
-
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
921
|
-
for index in range(rotations.shape[0]):
|
922
|
-
rotation = rotations[index]
|
923
|
-
backend.fill(arr, 0)
|
924
|
-
backend.fill(temp, 0)
|
925
|
-
backend.rotate_array(
|
926
|
-
arr=template,
|
927
|
-
arr_mask=template_mask,
|
928
|
-
rotation_matrix=rotation,
|
929
|
-
out=arr,
|
930
|
-
out_mask=temp,
|
931
|
-
use_geometric_center=True,
|
932
|
-
order=interpolation_order,
|
933
|
-
)
|
934
|
-
# Given the amount of FFTs, might aswell normalize properly
|
935
|
-
n_observations = backend.sum(temp)
|
936
|
-
|
937
|
-
normalize_under_mask(
|
938
|
-
template=arr[unpadded_slice],
|
939
|
-
mask=temp[unpadded_slice],
|
940
|
-
mask_intensity=n_observations,
|
941
|
-
)
|
942
|
-
|
943
|
-
rfftn(temp, ft_temp)
|
944
|
-
|
945
|
-
backend.multiply(ft_target, ft_temp, out=ft_denom)
|
946
|
-
irfftn(ft_denom, temp)
|
947
|
-
backend.divide(temp, n_observations, out=temp)
|
948
|
-
backend.square(temp, out=temp)
|
949
|
-
|
950
|
-
backend.multiply(ft_target2, ft_temp, out=ft_denom)
|
951
|
-
irfftn(ft_denom, temp2)
|
952
|
-
backend.divide(temp2, n_observations, out=temp2)
|
953
|
-
|
954
|
-
backend.subtract(temp2, temp, out=temp)
|
955
|
-
backend.maximum(temp, 0.0, out=temp)
|
956
|
-
backend.sqrt(temp, out=temp)
|
957
|
-
backend.multiply(temp, n_observations, out=temp)
|
958
|
-
|
959
|
-
rfftn(arr, ft_temp)
|
960
|
-
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
961
|
-
backend.multiply(ft_target, ft_temp, out=ft_temp)
|
962
|
-
irfftn(ft_temp, arr)
|
963
|
-
|
964
|
-
tol = eps
|
965
|
-
# tol = 1e3 * eps * backend.max(backend.abs(temp))
|
966
|
-
nonzero_indices = temp > tol
|
967
|
-
backend.fill(temp2, 0)
|
968
|
-
temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
|
969
|
-
|
970
|
-
callback_func(
|
971
|
-
temp2,
|
972
|
-
rotation_matrix=rotation,
|
973
|
-
rotation_index=index,
|
974
|
-
**callback_class_args,
|
975
|
-
)
|
976
|
-
|
977
|
-
return callback
|
978
|
-
|
979
|
-
|
980
|
-
def flc_scoring2(
|
981
|
-
template: Tuple[type, Tuple[int], type],
|
982
|
-
template_mask: Tuple[type, Tuple[int], type],
|
983
|
-
ft_target: Tuple[type, Tuple[int], type],
|
984
|
-
ft_target2: Tuple[type, Tuple[int], type],
|
985
|
-
template_filter: Tuple[type, Tuple[int], type],
|
986
|
-
targetshape: Tuple[int],
|
987
|
-
templateshape: Tuple[int],
|
988
|
-
fast_shape: Tuple[int],
|
989
|
-
fast_ft_shape: Tuple[int],
|
990
|
-
rotations: NDArray,
|
991
|
-
callback_class: CallbackClass,
|
992
|
-
callback_class_args: Dict,
|
993
|
-
interpolation_order: int,
|
994
|
-
**kwargs,
|
995
|
-
) -> CallbackClass:
|
996
|
-
"""
|
997
|
-
Computes a normalized cross-correlation score of a target f a template g
|
998
|
-
and a mask m:
|
999
|
-
|
1000
|
-
.. math::
|
1001
|
-
|
1002
|
-
\\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
|
1003
|
-
{N_m * \\sqrt{
|
1004
|
-
\\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
|
1005
|
-
}
|
1006
|
-
|
1007
|
-
Where:
|
1008
|
-
|
1009
|
-
.. math::
|
1010
|
-
|
1011
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
1012
|
-
|
1013
|
-
and Nm is the number of voxels within the template mask m.
|
1014
|
-
|
1015
|
-
References
|
1016
|
-
----------
|
1017
|
-
.. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
|
1018
|
-
Microsc. Microanal. 26, 2516 (2020)
|
1019
|
-
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
1020
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
1021
|
-
"""
|
1022
|
-
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
1023
|
-
callback = callback_class
|
1024
|
-
if callback_class is not None and isinstance(callback_class, type):
|
1025
|
-
callback = callback_class(**callback_class_args)
|
1026
|
-
|
1027
|
-
# Retrieve objects from shared memory
|
1028
|
-
template = backend.sharedarr_to_arr(*template)
|
1029
|
-
template_mask = backend.sharedarr_to_arr(*template_mask)
|
1030
|
-
ft_target = backend.sharedarr_to_arr(*ft_target)
|
1031
|
-
ft_target2 = backend.sharedarr_to_arr(*ft_target2)
|
1032
|
-
template_filter = backend.sharedarr_to_arr(*template_filter)
|
1033
|
-
|
1034
|
-
arr = backend.preallocate_array(fast_shape, real_dtype)
|
1035
|
-
temp = backend.preallocate_array(fast_shape, real_dtype)
|
1036
|
-
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
1037
|
-
|
1038
|
-
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
1039
|
-
ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
1040
|
-
|
1041
|
-
eps = backend.eps(real_dtype)
|
1042
|
-
template_filter_func = conditional_execute(
|
1043
|
-
apply_filter, backend.size(template_filter) != 1
|
1044
|
-
)
|
1045
|
-
callback_func = conditional_execute(callback, callback_class is not None)
|
1046
|
-
|
1047
|
-
squeeze_axis = tuple(i for i, x in enumerate(template.shape) if x == 1)
|
1048
|
-
squeeze = tuple(
|
1049
|
-
slice(0, stop) if i not in squeeze_axis else 0
|
1050
|
-
for i, stop in enumerate(template.shape)
|
1051
|
-
)
|
1052
|
-
squeeze_fast = tuple(
|
1053
|
-
slice(0, stop) if i not in squeeze_axis else 0
|
1054
|
-
for i, stop in enumerate(fast_shape)
|
1055
|
-
)
|
1056
|
-
squeeze_fast_ft = tuple(
|
1057
|
-
slice(0, stop) if i not in squeeze_axis else 0
|
1058
|
-
for i, stop in enumerate(fast_ft_shape)
|
1059
|
-
)
|
1060
|
-
|
1061
|
-
rfftn, irfftn = backend.build_fft(
|
1062
|
-
fast_shape=temp[squeeze_fast].shape,
|
1063
|
-
fast_ft_shape=fast_ft_shape,
|
1064
|
-
real_dtype=real_dtype,
|
1065
|
-
complex_dtype=complex_dtype,
|
1066
|
-
fftargs=kwargs.get("fftargs", {}),
|
1067
|
-
inverse_fast_shape=fast_shape,
|
1068
|
-
temp_real=arr[squeeze_fast],
|
1069
|
-
temp_fft=ft_temp,
|
1070
|
-
)
|
1071
|
-
for index in range(rotations.shape[0]):
|
1072
|
-
rotation = rotations[index]
|
1073
|
-
backend.fill(arr, 0)
|
1074
|
-
backend.fill(temp, 0)
|
1075
|
-
backend.rotate_array(
|
1076
|
-
arr=template[squeeze],
|
1077
|
-
arr_mask=template_mask[squeeze],
|
1078
|
-
rotation_matrix=rotation,
|
1079
|
-
out=arr[squeeze],
|
1080
|
-
out_mask=temp[squeeze],
|
1081
|
-
use_geometric_center=True,
|
1082
|
-
order=interpolation_order,
|
1083
|
-
)
|
1084
|
-
# Given the amount of FFTs, might aswell normalize properly
|
1085
|
-
n_observations = backend.sum(temp)
|
1086
|
-
|
1087
|
-
normalize_under_mask(
|
1088
|
-
template=arr[squeeze],
|
1089
|
-
mask=temp[squeeze],
|
1090
|
-
mask_intensity=n_observations,
|
1091
|
-
)
|
1092
|
-
rfftn(temp[squeeze_fast], ft_temp[squeeze_fast_ft])
|
1093
|
-
|
1094
|
-
backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1095
|
-
irfftn(ft_denom, temp)
|
1096
|
-
backend.divide(temp, n_observations, out=temp)
|
1097
|
-
backend.square(temp, out=temp)
|
1098
|
-
|
1099
|
-
backend.multiply(ft_target2, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1100
|
-
irfftn(ft_denom, temp2)
|
1101
|
-
backend.divide(temp2, n_observations, out=temp2)
|
1102
|
-
|
1103
|
-
backend.subtract(temp2, temp, out=temp)
|
1104
|
-
backend.maximum(temp, 0.0, out=temp)
|
1105
|
-
backend.sqrt(temp, out=temp)
|
1106
|
-
backend.multiply(temp, n_observations, out=temp)
|
1107
|
-
|
1108
|
-
rfftn(arr[squeeze_fast], ft_temp[squeeze_fast_ft])
|
1109
|
-
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
1110
|
-
backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1111
|
-
irfftn(ft_denom, arr)
|
1112
|
-
|
1113
|
-
nonzero_indices = temp > eps
|
1114
|
-
backend.fill(temp2, 0)
|
1115
|
-
temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
|
1116
|
-
|
1117
|
-
callback_func(
|
1118
|
-
temp2,
|
1119
|
-
rotation_matrix=rotation,
|
1120
|
-
rotation_index=index,
|
1121
|
-
**callback_class_args,
|
1122
|
-
)
|
1123
|
-
return callback
|
1124
|
-
|
1125
|
-
|
1126
|
-
def mcc_scoring(
|
1127
|
-
template: Tuple[type, Tuple[int], type],
|
1128
|
-
template_mask: Tuple[type, Tuple[int], type],
|
1129
|
-
ft_target: Tuple[type, Tuple[int], type],
|
1130
|
-
ft_target2: Tuple[type, Tuple[int], type],
|
1131
|
-
ft_target_mask: Tuple[type, Tuple[int], type],
|
1132
|
-
template_filter: Tuple[type, Tuple[int], type],
|
1133
|
-
targetshape: Tuple[int],
|
1134
|
-
templateshape: Tuple[int],
|
1135
|
-
fast_shape: Tuple[int],
|
1136
|
-
fast_ft_shape: Tuple[int],
|
1137
|
-
rotations: NDArray,
|
1138
|
-
callback_class: CallbackClass,
|
1139
|
-
callback_class_args: type,
|
1140
|
-
interpolation_order: int,
|
1141
|
-
overlap_ratio: float = 0.3,
|
1142
|
-
**kwargs,
|
1143
|
-
) -> CallbackClass:
|
1144
|
-
"""
|
1145
|
-
Computes a cross-correlation score with masks for both template and target.
|
1146
|
-
|
1147
|
-
.. math::
|
1148
|
-
|
1149
|
-
\\frac{
|
1150
|
-
CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
|
1151
|
-
}{
|
1152
|
-
\\sqrt{
|
1153
|
-
(CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
|
1154
|
-
(CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
|
1155
|
-
}
|
1156
|
-
}
|
1157
|
-
|
1158
|
-
Where:
|
1159
|
-
|
1160
|
-
.. math::
|
1161
|
-
|
1162
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
1163
|
-
|
1164
|
-
|
1165
|
-
References
|
1166
|
-
----------
|
1167
|
-
.. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
|
1168
|
-
.. [2] https://scikit-image.org/docs/stable/api/skimage.registration.html
|
31
|
+
last_type : type
|
32
|
+
The type of the last exception.
|
33
|
+
last_value :
|
34
|
+
The value of the last exception.
|
35
|
+
last_traceback : traceback
|
36
|
+
Traceback call stack at the point where the Exception occured.
|
1169
37
|
|
1170
|
-
|
1171
|
-
|
1172
|
-
|
38
|
+
Raises
|
39
|
+
------
|
40
|
+
Exception
|
41
|
+
Re-raises the last exception.
|
1173
42
|
"""
|
1174
|
-
|
1175
|
-
|
1176
|
-
|
1177
|
-
|
1178
|
-
|
1179
|
-
# Retrieve objects from shared memory
|
1180
|
-
template_buffer, template_shape, template_dtype = template
|
1181
|
-
template = backend.sharedarr_to_arr(*template)
|
1182
|
-
target_ft = backend.sharedarr_to_arr(*ft_target)
|
1183
|
-
target_ft2 = backend.sharedarr_to_arr(*ft_target2)
|
1184
|
-
template_mask = backend.sharedarr_to_arr(*template_mask)
|
1185
|
-
target_mask_ft = backend.sharedarr_to_arr(*ft_target_mask)
|
1186
|
-
template_filter = backend.sharedarr_to_arr(*template_filter)
|
1187
|
-
|
1188
|
-
axes = tuple(range(template.ndim))
|
1189
|
-
eps = backend.eps(real_dtype)
|
1190
|
-
|
1191
|
-
# Allocate score and process specific arrays
|
1192
|
-
template_rot = backend.preallocate_array(fast_shape, real_dtype)
|
1193
|
-
mask_overlap = backend.preallocate_array(fast_shape, real_dtype)
|
1194
|
-
numerator = backend.preallocate_array(fast_shape, real_dtype)
|
1195
|
-
|
1196
|
-
temp = backend.preallocate_array(fast_shape, real_dtype)
|
1197
|
-
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
1198
|
-
temp3 = backend.preallocate_array(fast_shape, real_dtype)
|
1199
|
-
temp_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
1200
|
-
|
1201
|
-
rfftn, irfftn = backend.build_fft(
|
1202
|
-
fast_shape=fast_shape,
|
1203
|
-
fast_ft_shape=fast_ft_shape,
|
1204
|
-
real_dtype=real_dtype,
|
1205
|
-
complex_dtype=complex_dtype,
|
1206
|
-
fftargs=kwargs.get("fftargs", {}),
|
1207
|
-
temp_real=numerator,
|
1208
|
-
temp_fft=temp_ft,
|
1209
|
-
)
|
1210
|
-
|
1211
|
-
template_filter_func = conditional_execute(
|
1212
|
-
apply_filter, backend.size(template_filter) != 1
|
1213
|
-
)
|
1214
|
-
callback_func = conditional_execute(callback, callback_class is not None)
|
1215
|
-
|
1216
|
-
# Calculate scores across all rotations
|
1217
|
-
for index in range(rotations.shape[0]):
|
1218
|
-
rotation = rotations[index]
|
1219
|
-
backend.fill(template_rot, 0)
|
1220
|
-
backend.fill(temp, 0)
|
1221
|
-
|
1222
|
-
backend.rotate_array(
|
1223
|
-
arr=template,
|
1224
|
-
arr_mask=template_mask,
|
1225
|
-
rotation_matrix=rotation,
|
1226
|
-
out=template_rot,
|
1227
|
-
out_mask=temp,
|
1228
|
-
use_geometric_center=True,
|
1229
|
-
order=interpolation_order,
|
1230
|
-
)
|
1231
|
-
|
1232
|
-
backend.multiply(template_rot, temp > 0, out=template_rot)
|
1233
|
-
|
1234
|
-
# template_rot_ft
|
1235
|
-
rfftn(template_rot, temp_ft)
|
1236
|
-
template_filter_func(ft_template=temp_ft, template_filter=template_filter)
|
1237
|
-
irfftn(target_mask_ft * temp_ft, temp2)
|
1238
|
-
irfftn(target_ft * temp_ft, numerator)
|
1239
|
-
|
1240
|
-
# temp template_mask_rot | temp_ft template_mask_rot_ft
|
1241
|
-
# Calculate overlap of masks at every point in the convolution.
|
1242
|
-
# Locations with high overlap should not be taken into account.
|
1243
|
-
rfftn(temp, temp_ft)
|
1244
|
-
irfftn(temp_ft * target_mask_ft, mask_overlap)
|
1245
|
-
mask_overlap[:] = np.round(mask_overlap)
|
1246
|
-
mask_overlap[:] = np.maximum(mask_overlap, eps)
|
1247
|
-
irfftn(temp_ft * target_ft, temp)
|
1248
|
-
|
1249
|
-
backend.subtract(
|
1250
|
-
numerator,
|
1251
|
-
backend.divide(backend.multiply(temp, temp2), mask_overlap),
|
1252
|
-
out=numerator,
|
1253
|
-
)
|
43
|
+
if last_type is None:
|
44
|
+
return None
|
45
|
+
traceback.print_tb(last_traceback)
|
46
|
+
raise Exception(last_value)
|
1254
47
|
|
1255
|
-
# temp_3 = fixed_denom
|
1256
|
-
backend.multiply(temp_ft, target_ft2, out=temp_ft)
|
1257
|
-
irfftn(temp_ft, temp3)
|
1258
|
-
backend.subtract(
|
1259
|
-
temp3, backend.divide(backend.square(temp), mask_overlap), out=temp3
|
1260
|
-
)
|
1261
|
-
backend.maximum(temp3, 0.0, out=temp3)
|
1262
|
-
|
1263
|
-
# temp = moving_denom
|
1264
|
-
rfftn(backend.square(template_rot), temp_ft)
|
1265
|
-
backend.multiply(target_mask_ft, temp_ft, out=temp_ft)
|
1266
|
-
irfftn(temp_ft, temp)
|
1267
|
-
|
1268
|
-
backend.subtract(
|
1269
|
-
temp, backend.divide(backend.square(temp2), mask_overlap), out=temp
|
1270
|
-
)
|
1271
|
-
backend.maximum(temp, 0.0, out=temp)
|
1272
48
|
|
1273
|
-
|
1274
|
-
|
1275
|
-
|
1276
|
-
|
1277
|
-
# Pixels where `denom` is very small will introduce large
|
1278
|
-
# numbers after division. To get around this problem,
|
1279
|
-
# we zero-out problematic pixels.
|
1280
|
-
tol = 1e3 * eps * backend.max(backend.abs(temp2), axis=axes, keepdims=True)
|
1281
|
-
nonzero_indices = temp2 > tol
|
1282
|
-
|
1283
|
-
backend.fill(temp, 0)
|
1284
|
-
temp[nonzero_indices] = numerator[nonzero_indices] / temp2[nonzero_indices]
|
1285
|
-
backend.clip(temp, a_min=-1, a_max=1, out=temp)
|
1286
|
-
|
1287
|
-
# Apply overlap ratio threshold
|
1288
|
-
number_px_threshold = overlap_ratio * backend.max(
|
1289
|
-
mask_overlap, axis=axes, keepdims=True
|
1290
|
-
)
|
1291
|
-
temp[mask_overlap < number_px_threshold] = 0.0
|
49
|
+
def _wrap_backend(func):
|
50
|
+
@wraps(func)
|
51
|
+
def wrapper(*args, backend_name: str, backend_args: Dict, **kwargs):
|
52
|
+
from tme.backends import backend as be
|
1292
53
|
|
1293
|
-
|
1294
|
-
|
1295
|
-
rotation_matrix=rotation,
|
1296
|
-
rotation_index=index,
|
1297
|
-
**callback_class_args,
|
1298
|
-
)
|
54
|
+
be.change_backend(backend_name, **backend_args)
|
55
|
+
return func(*args, **kwargs)
|
1299
56
|
|
1300
|
-
return
|
57
|
+
return wrapper
|
1301
58
|
|
1302
59
|
|
1303
60
|
def _setup_template_filter_apply_target_filter(
|
@@ -1306,42 +63,59 @@ def _setup_template_filter_apply_target_filter(
|
|
1306
63
|
irfftn: Callable,
|
1307
64
|
fast_shape: Tuple[int],
|
1308
65
|
fast_ft_shape: Tuple[int],
|
66
|
+
pad_template_filter: bool = True,
|
1309
67
|
):
|
1310
68
|
filter_template = isinstance(matching_data.template_filter, Compose)
|
1311
69
|
filter_target = isinstance(matching_data.target_filter, Compose)
|
1312
70
|
|
1313
|
-
template_filter =
|
71
|
+
template_filter = be.full(shape=(1,), fill_value=1, dtype=be._float_dtype)
|
1314
72
|
|
1315
73
|
if not filter_template and not filter_target:
|
1316
74
|
return template_filter
|
1317
75
|
|
1318
|
-
|
1319
|
-
|
1320
|
-
)
|
1321
|
-
|
1322
|
-
|
1323
|
-
|
1324
|
-
|
1325
|
-
|
1326
|
-
|
1327
|
-
|
1328
|
-
|
1329
|
-
|
76
|
+
inv_mask = be.subtract(1, be.to_backend_array(matching_data._batch_mask))
|
77
|
+
filter_shape = be.multiply(be.to_backend_array(fast_ft_shape), inv_mask)
|
78
|
+
filter_shape = tuple(int(x) if x != 0 else 1 for x in filter_shape)
|
79
|
+
fast_shape = be.multiply(be.to_backend_array(fast_shape), inv_mask)
|
80
|
+
fast_shape = tuple(int(x) for x in fast_shape if x != 0)
|
81
|
+
|
82
|
+
fastt_shape, fastt_ft_shape = fast_shape, filter_shape
|
83
|
+
if filter_template and not pad_template_filter:
|
84
|
+
# FFT shape acrobatics for faster filter application
|
85
|
+
_, fastt_shape, _, _ = matching_data._fourier_padding(
|
86
|
+
target_shape=be.to_numpy_array(matching_data._template.shape),
|
87
|
+
template_shape=be.to_numpy_array(
|
88
|
+
[1 for _ in matching_data._template.shape]
|
89
|
+
),
|
90
|
+
pad_fourier=False,
|
91
|
+
)
|
92
|
+
matching_data.template = be.reverse(
|
93
|
+
be.topleft_pad(matching_data.template, fastt_shape)
|
94
|
+
)
|
95
|
+
matching_data.template_mask = be.reverse(
|
96
|
+
be.topleft_pad(matching_data.template_mask, fastt_shape)
|
97
|
+
)
|
98
|
+
matching_data._set_matching_dimension(
|
99
|
+
target_dims=matching_data._target_dims,
|
100
|
+
template_dims=matching_data._template_dims,
|
101
|
+
)
|
102
|
+
fastt_ft_shape = [int(x) for x in matching_data._output_template_shape]
|
103
|
+
fastt_ft_shape[-1] = fastt_ft_shape[-1] // 2 + 1
|
1330
104
|
|
1331
|
-
|
1332
|
-
|
105
|
+
target_temp = be.topleft_pad(matching_data.target, fast_shape)
|
106
|
+
target_temp_ft = be.zeros(fast_ft_shape, be._complex_dtype)
|
107
|
+
target_temp_ft = rfftn(target_temp, target_temp_ft)
|
108
|
+
if filter_template:
|
1333
109
|
template_filter = matching_data.template_filter(
|
1334
|
-
shape=
|
110
|
+
shape=fastt_shape,
|
1335
111
|
return_real_fourier=True,
|
1336
112
|
shape_is_real_fourier=False,
|
1337
113
|
data_rfft=target_temp_ft,
|
1338
114
|
batch_dimension=matching_data._target_dims,
|
1339
|
-
)
|
1340
|
-
template_filter = template_filter
|
1341
|
-
template_filter[tuple(0 for _ in range(template_filter.ndim))] = 0
|
1342
|
-
template_filter = backend.reshape(template_filter, filter_shape)
|
115
|
+
)["data"]
|
116
|
+
template_filter = be.reshape(template_filter, fastt_ft_shape)
|
1343
117
|
|
1344
|
-
if
|
118
|
+
if filter_target:
|
1345
119
|
target_filter = matching_data.target_filter(
|
1346
120
|
shape=fast_shape,
|
1347
121
|
return_real_fourier=True,
|
@@ -1349,15 +123,12 @@ def _setup_template_filter_apply_target_filter(
|
|
1349
123
|
data_rfft=target_temp_ft,
|
1350
124
|
weight_type=None,
|
1351
125
|
batch_dimension=matching_data._target_dims,
|
1352
|
-
)
|
1353
|
-
target_filter = target_filter
|
1354
|
-
|
1355
|
-
backend.multiply(target_temp_ft, target_filter, out=target_temp_ft)
|
126
|
+
)["data"]
|
127
|
+
target_filter = be.reshape(target_filter, filter_shape)
|
128
|
+
target_temp_ft = be.multiply(target_temp_ft, target_filter, out=target_temp_ft)
|
1356
129
|
|
1357
|
-
irfftn(target_temp_ft, target_temp)
|
1358
|
-
matching_data._target =
|
1359
|
-
target_temp, matching_data.target.shape
|
1360
|
-
)
|
130
|
+
target_temp = irfftn(target_temp_ft, target_temp)
|
131
|
+
matching_data._target = be.topleft_pad(target_temp, matching_data.target.shape)
|
1361
132
|
|
1362
133
|
return template_filter
|
1363
134
|
|
@@ -1371,13 +142,14 @@ def device_memory_handler(func: Callable):
|
|
1371
142
|
last_type, last_value, last_traceback = sys.exc_info()
|
1372
143
|
try:
|
1373
144
|
with SharedMemoryManager() as smh:
|
1374
|
-
|
145
|
+
gpu_index = kwargs.pop("gpu_index") if "gpu_index" in kwargs else 0
|
146
|
+
with be.set_device(gpu_index):
|
1375
147
|
return_value = func(shared_memory_handler=smh, *args, **kwargs)
|
1376
148
|
except Exception as e:
|
1377
149
|
print(e)
|
1378
150
|
last_type, last_value, last_traceback = sys.exc_info()
|
1379
151
|
finally:
|
1380
|
-
|
152
|
+
_handle_traceback(last_type, last_value, last_traceback)
|
1381
153
|
return return_value
|
1382
154
|
|
1383
155
|
return inner_function
|
@@ -1391,18 +163,20 @@ def scan(
|
|
1391
163
|
n_jobs: int = 4,
|
1392
164
|
callback_class: CallbackClass = None,
|
1393
165
|
callback_class_args: Dict = {},
|
1394
|
-
fftargs: Dict = {},
|
1395
166
|
pad_fourier: bool = True,
|
167
|
+
pad_template_filter: bool = True,
|
1396
168
|
interpolation_order: int = 3,
|
1397
169
|
jobs_per_callback_class: int = 8,
|
1398
|
-
|
1399
|
-
) -> Tuple:
|
170
|
+
shared_memory_handler=None,
|
171
|
+
) -> Optional[Tuple]:
|
1400
172
|
"""
|
1401
|
-
|
173
|
+
Run template matching.
|
174
|
+
|
175
|
+
.. warning:: ``matching_data`` might be altered or destroyed during computation.
|
1402
176
|
|
1403
177
|
Parameters
|
1404
178
|
----------
|
1405
|
-
matching_data : MatchingData
|
179
|
+
matching_data : :py:class:`tme.matching_data.MatchingData`
|
1406
180
|
Template matching data.
|
1407
181
|
matching_setup : Callable
|
1408
182
|
Function pointer to setup function.
|
@@ -1414,21 +188,21 @@ def scan(
|
|
1414
188
|
Analyzer class pointer to operate on computed scores.
|
1415
189
|
callback_class_args : dict, optional
|
1416
190
|
Arguments passed to the callback_class. Default is an empty dictionary.
|
1417
|
-
fftargs : dict, optional
|
1418
|
-
Arguments for the FFT operations. Default is an empty dictionary.
|
1419
191
|
pad_fourier: bool, optional
|
1420
192
|
Whether to pad target and template to the full convolution shape.
|
193
|
+
pad_template_filter: bool, optional
|
194
|
+
Whether to pad potential template filters to the full convolution shape.
|
1421
195
|
interpolation_order : int, optional
|
1422
196
|
Order of spline interpolation for rotations.
|
1423
197
|
jobs_per_callback_class : int, optional
|
1424
198
|
How many jobs should be processed by a single callback_class instance,
|
1425
199
|
if one is provided.
|
1426
|
-
|
1427
|
-
|
200
|
+
shared_memory_handler : type, optional
|
201
|
+
Manager for shared memory objects, None by default.
|
1428
202
|
|
1429
203
|
Returns
|
1430
204
|
-------
|
1431
|
-
Tuple
|
205
|
+
Optional[Tuple]
|
1432
206
|
The merged results from callback_class if provided otherwise None.
|
1433
207
|
|
1434
208
|
Examples
|
@@ -1450,159 +224,100 @@ def scan(
|
|
1450
224
|
|
1451
225
|
"""
|
1452
226
|
matching_data.to_backend()
|
1453
|
-
|
1454
|
-
|
1455
|
-
|
1456
|
-
|
1457
|
-
|
1458
|
-
|
1459
|
-
|
1460
|
-
|
1461
|
-
|
1462
|
-
" zero-pad the target or turn off template centering."
|
1463
|
-
)
|
1464
|
-
fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
|
1465
|
-
pad_fourier=pad_fourier
|
1466
|
-
)
|
1467
|
-
|
1468
|
-
callback_class_args["fourier_shift"] = fourier_shift
|
1469
|
-
rfftn, irfftn = backend.build_fft(
|
227
|
+
(
|
228
|
+
conv_shape,
|
229
|
+
fast_shape,
|
230
|
+
fast_ft_shape,
|
231
|
+
fourier_shift,
|
232
|
+
) = matching_data.fourier_padding(pad_fourier=pad_fourier)
|
233
|
+
template_shape = matching_data.template.shape
|
234
|
+
|
235
|
+
rfftn, irfftn = be.build_fft(
|
1470
236
|
fast_shape=fast_shape,
|
1471
237
|
fast_ft_shape=fast_ft_shape,
|
1472
|
-
real_dtype=
|
1473
|
-
complex_dtype=
|
1474
|
-
fftargs=fftargs,
|
238
|
+
real_dtype=be._float_dtype,
|
239
|
+
complex_dtype=be._complex_dtype,
|
1475
240
|
)
|
1476
|
-
|
1477
241
|
template_filter = _setup_template_filter_apply_target_filter(
|
1478
242
|
matching_data=matching_data,
|
1479
243
|
rfftn=rfftn,
|
1480
244
|
irfftn=irfftn,
|
1481
245
|
fast_shape=fast_shape,
|
1482
246
|
fast_ft_shape=fast_ft_shape,
|
247
|
+
pad_template_filter=pad_template_filter,
|
1483
248
|
)
|
249
|
+
template_filter = be.astype(be.to_backend_array(template_filter), be._float_dtype)
|
1484
250
|
|
1485
251
|
setup = matching_setup(
|
1486
252
|
rfftn=rfftn,
|
1487
253
|
irfftn=irfftn,
|
1488
254
|
template=matching_data.template,
|
255
|
+
template_filter=template_filter,
|
1489
256
|
template_mask=matching_data.template_mask,
|
1490
257
|
target=matching_data.target,
|
1491
258
|
target_mask=matching_data.target_mask,
|
1492
259
|
fast_shape=fast_shape,
|
1493
260
|
fast_ft_shape=fast_ft_shape,
|
1494
|
-
|
1495
|
-
callback_class_args=callback_class_args,
|
1496
|
-
**kwargs,
|
261
|
+
shared_memory_handler=shared_memory_handler,
|
1497
262
|
)
|
1498
263
|
rfftn, irfftn = None, None
|
1499
|
-
|
1500
|
-
template_filter = backend.to_backend_array(template_filter)
|
1501
|
-
template_filter = backend.astype(template_filter, backend._float_dtype)
|
1502
|
-
template_filter_buffer = backend.arr_to_sharedarr(
|
1503
|
-
arr=template_filter,
|
1504
|
-
shared_memory_handler=kwargs.get("shared_memory_handler", None),
|
1505
|
-
)
|
1506
|
-
setup["template_filter"] = (
|
1507
|
-
template_filter_buffer,
|
1508
|
-
template_filter.shape,
|
1509
|
-
template_filter.dtype,
|
1510
|
-
)
|
1511
|
-
|
1512
|
-
callback_class_args["translation_offset"] = backend.astype(
|
1513
|
-
matching_data._translation_offset, int
|
1514
|
-
)
|
1515
|
-
callback_class_args["thread_safe"] = n_jobs > 1
|
1516
|
-
callback_class_args["gpu_index"] = kwargs.get("gpu_index", -1)
|
1517
|
-
|
1518
|
-
n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
|
1519
|
-
callback_class = setup.pop("callback_class", callback_class)
|
1520
|
-
callback_class_args = setup.pop("callback_class_args", callback_class_args)
|
1521
|
-
callback_classes = [callback_class for _ in range(n_callback_classes)]
|
1522
|
-
|
1523
|
-
convolution_mode = "same"
|
1524
|
-
if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
|
1525
|
-
convolution_mode = "valid"
|
1526
|
-
|
1527
|
-
callback_class_args["fourier_shift"] = fourier_shift
|
1528
|
-
callback_class_args["convolution_mode"] = convolution_mode
|
1529
|
-
callback_class_args["targetshape"] = setup["targetshape"]
|
1530
|
-
callback_class_args["templateshape"] = setup["templateshape"]
|
1531
|
-
|
1532
|
-
if callback_class == MaxScoreOverRotations:
|
1533
|
-
callback_classes = [
|
1534
|
-
class_name(
|
1535
|
-
score_space_shape=fast_shape,
|
1536
|
-
score_space_dtype=backend._float_dtype,
|
1537
|
-
shared_memory_handler=kwargs.get("shared_memory_handler", None),
|
1538
|
-
rotation_space_dtype=backend._int_dtype,
|
1539
|
-
**callback_class_args,
|
1540
|
-
)
|
1541
|
-
for class_name in callback_classes
|
1542
|
-
]
|
1543
|
-
|
1544
|
-
matching_data._target, matching_data._template = None, None
|
1545
|
-
matching_data._target_mask, matching_data._template_mask = None, None
|
1546
|
-
|
1547
|
-
setup["fftargs"] = fftargs.copy()
|
1548
|
-
convolution_mode = "same"
|
1549
|
-
if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
|
1550
|
-
convolution_mode = "valid"
|
1551
|
-
setup["convolution_mode"] = convolution_mode
|
1552
264
|
setup["interpolation_order"] = interpolation_order
|
1553
|
-
|
1554
|
-
|
1555
|
-
|
1556
|
-
|
1557
|
-
|
1558
|
-
|
265
|
+
setup["template_filter"] = be.to_sharedarr(template_filter, shared_memory_handler)
|
266
|
+
|
267
|
+
offset = be.to_backend_array(matching_data._translation_offset)
|
268
|
+
convmode = "valid" if getattr(matching_data, "_is_padded", False) else "same"
|
269
|
+
default_callback_args = {
|
270
|
+
"offset": be.astype(offset, be._int_dtype),
|
271
|
+
"thread_safe": n_jobs > 1,
|
272
|
+
"fourier_shift": fourier_shift,
|
273
|
+
"convolution_mode": convmode,
|
274
|
+
"targetshape": matching_data.target.shape,
|
275
|
+
"templateshape": template_shape,
|
276
|
+
"convolution_shape": conv_shape,
|
277
|
+
"fast_shape": fast_shape,
|
278
|
+
"indices": getattr(matching_data, "indices", None),
|
279
|
+
"shared_memory_handler": shared_memory_handler,
|
280
|
+
"only_unique_rotations": True,
|
281
|
+
}
|
282
|
+
default_callback_args.update(callback_class_args)
|
1559
283
|
|
1560
|
-
|
1561
|
-
|
284
|
+
matching_data._free_data()
|
285
|
+
be.free_cache()
|
1562
286
|
|
287
|
+
# For highly parallel jobs, blocking in certain analyzers becomes a bottleneck
|
288
|
+
if getattr(callback_class, "shared", True):
|
289
|
+
jobs_per_callback_class = 1
|
290
|
+
n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
|
291
|
+
callback_classes = [
|
292
|
+
callback_class(
|
293
|
+
shape=fast_shape,
|
294
|
+
**default_callback_args,
|
295
|
+
)
|
296
|
+
if callback_class is not None
|
297
|
+
else None
|
298
|
+
for _ in range(n_callback_classes)
|
299
|
+
]
|
1563
300
|
callbacks = Parallel(n_jobs=n_jobs)(
|
1564
|
-
delayed(
|
1565
|
-
backend_name=
|
1566
|
-
backend_args=
|
301
|
+
delayed(_wrap_backend(matching_score))(
|
302
|
+
backend_name=be._backend_name,
|
303
|
+
backend_args=be._backend_args,
|
1567
304
|
rotations=rotation,
|
1568
|
-
|
1569
|
-
callback_class_args=callback_class_args,
|
305
|
+
callback=callback_classes[index % n_callback_classes],
|
1570
306
|
**setup,
|
1571
307
|
)
|
1572
|
-
for index, rotation in enumerate(
|
308
|
+
for index, rotation in enumerate(matching_data._split_rotations_on_jobs(n_jobs))
|
1573
309
|
)
|
1574
310
|
|
1575
|
-
callbacks = callbacks[0:n_callback_classes]
|
1576
311
|
callbacks = [
|
1577
|
-
tuple(
|
1578
|
-
callback._postprocess(
|
1579
|
-
fourier_shift=fourier_shift,
|
1580
|
-
convolution_mode=convolution_mode,
|
1581
|
-
targetshape=setup["targetshape"],
|
1582
|
-
templateshape=setup["templateshape"],
|
1583
|
-
shared_memory_handler=kwargs.get("shared_memory_handler", None),
|
1584
|
-
)
|
1585
|
-
)
|
1586
|
-
if hasattr(callback, "_postprocess")
|
1587
|
-
else tuple(callback)
|
312
|
+
tuple(callback._postprocess(**default_callback_args))
|
1588
313
|
for callback in callbacks
|
1589
314
|
if callback is not None
|
1590
315
|
]
|
1591
|
-
|
316
|
+
be.free_cache()
|
1592
317
|
|
1593
|
-
merged_callback = None
|
1594
318
|
if callback_class is not None:
|
1595
|
-
|
1596
|
-
|
1597
|
-
score_indices = matching_data.indices
|
1598
|
-
merged_callback = callback_class.merge(
|
1599
|
-
callbacks,
|
1600
|
-
**callback_class_args,
|
1601
|
-
score_indices=score_indices,
|
1602
|
-
inner_merge=True,
|
1603
|
-
)
|
1604
|
-
|
1605
|
-
return merged_callback
|
319
|
+
return callback_class.merge(callbacks, **default_callback_args)
|
320
|
+
return None
|
1606
321
|
|
1607
322
|
|
1608
323
|
def scan_subsets(
|
@@ -1616,10 +331,12 @@ def scan_subsets(
|
|
1616
331
|
template_splits: Dict = {},
|
1617
332
|
pad_target_edges: bool = False,
|
1618
333
|
pad_fourier: bool = True,
|
334
|
+
pad_template_filter: bool = True,
|
1619
335
|
interpolation_order: int = 3,
|
1620
336
|
jobs_per_callback_class: int = 8,
|
1621
|
-
|
1622
|
-
|
337
|
+
backend_name: str = None,
|
338
|
+
backend_args: Dict = {},
|
339
|
+
) -> Optional[Tuple]:
|
1623
340
|
"""
|
1624
341
|
Wrapper around :py:meth:`scan` that supports matching on splits
|
1625
342
|
of ``matching_data``.
|
@@ -1651,21 +368,17 @@ def scan_subsets(
|
|
1651
368
|
along each axis.
|
1652
369
|
pad_fourier: bool, optional
|
1653
370
|
Whether to pad target and template to the full convolution shape.
|
371
|
+
pad_template_filter: bool, optional
|
372
|
+
Whether to pad potential template filters to the full convolution shape.
|
1654
373
|
interpolation_order : int, optional
|
1655
374
|
Order of spline interpolation for rotations.
|
1656
375
|
jobs_per_callback_class : int, optional
|
1657
376
|
How many jobs should be processed by a single callback_class instance,
|
1658
377
|
if ones is provided.
|
1659
|
-
**kwargs : various
|
1660
|
-
Additional arguments.
|
1661
|
-
|
1662
|
-
Notes
|
1663
|
-
-----
|
1664
|
-
Objects in matching_data might be destroyed during computation.
|
1665
378
|
|
1666
379
|
Returns
|
1667
380
|
-------
|
1668
|
-
Tuple
|
381
|
+
Optional[Tuple]
|
1669
382
|
The merged results from callback_class if provided otherwise None.
|
1670
383
|
|
1671
384
|
Examples
|
@@ -1720,73 +433,59 @@ def scan_subsets(
|
|
1720
433
|
>>> target_splits = target_splits,
|
1721
434
|
>>> )
|
1722
435
|
|
1723
|
-
The
|
436
|
+
The ``results`` tuple contains the output of the chosen analyzer.
|
1724
437
|
|
1725
438
|
See Also
|
1726
439
|
--------
|
1727
440
|
:py:meth:`tme.matching_utils.compute_parallelization_schedule`
|
1728
441
|
"""
|
1729
|
-
|
1730
|
-
|
1731
|
-
)
|
442
|
+
template_splits = split_shape(matching_data._template.shape, splits=template_splits)
|
443
|
+
target_splits = split_shape(matching_data._target.shape, splits=target_splits)
|
1732
444
|
if (len(target_splits) > 1) and not pad_target_edges:
|
1733
445
|
warnings.warn(
|
1734
446
|
"Target splitting without padding target edges leads to unreliable "
|
1735
447
|
"similarity estimates around the split border."
|
1736
448
|
)
|
1737
|
-
|
1738
|
-
template_splits = split_numpy_array_slices(
|
1739
|
-
matching_data._template.shape, splits=template_splits
|
1740
|
-
)
|
1741
|
-
target_pad = matching_data.target_padding(pad_target=pad_target_edges)
|
449
|
+
splits = tuple(product(target_splits, template_splits))
|
1742
450
|
|
1743
451
|
outer_jobs, inner_jobs = job_schedule
|
1744
|
-
|
1745
|
-
|
1746
|
-
|
1747
|
-
|
1748
|
-
matching_data=matching_data
|
1749
|
-
|
1750
|
-
|
1751
|
-
|
1752
|
-
),
|
1753
|
-
matching_score=matching_score,
|
1754
|
-
matching_setup=matching_setup,
|
1755
|
-
n_jobs=inner_jobs,
|
452
|
+
target_pad = matching_data.target_padding(pad_target=pad_target_edges)
|
453
|
+
if hasattr(be, "scan"):
|
454
|
+
corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
|
455
|
+
results = be.scan(
|
456
|
+
matching_data=matching_data,
|
457
|
+
splits=splits,
|
458
|
+
n_jobs=outer_jobs,
|
459
|
+
rotate_mask=matching_score != corr_scoring,
|
1756
460
|
callback_class=callback_class,
|
1757
|
-
callback_class_args=callback_class_args,
|
1758
|
-
interpolation_order=interpolation_order,
|
1759
|
-
pad_fourier=pad_fourier,
|
1760
|
-
gpu_index=index % outer_jobs,
|
1761
|
-
**kwargs,
|
1762
461
|
)
|
1763
|
-
|
1764
|
-
|
462
|
+
else:
|
463
|
+
results = Parallel(n_jobs=outer_jobs)(
|
464
|
+
delayed(_wrap_backend(scan))(
|
465
|
+
backend_name=be._backend_name,
|
466
|
+
backend_args=be._backend_args,
|
467
|
+
matching_data=matching_data.subset_by_slice(
|
468
|
+
target_slice=target_split,
|
469
|
+
target_pad=target_pad,
|
470
|
+
template_slice=template_split,
|
471
|
+
),
|
472
|
+
matching_score=matching_score,
|
473
|
+
matching_setup=matching_setup,
|
474
|
+
n_jobs=inner_jobs,
|
475
|
+
callback_class=callback_class,
|
476
|
+
callback_class_args=callback_class_args,
|
477
|
+
interpolation_order=interpolation_order,
|
478
|
+
pad_fourier=pad_fourier,
|
479
|
+
gpu_index=index % outer_jobs,
|
480
|
+
pad_template_filter=pad_template_filter,
|
481
|
+
)
|
482
|
+
for index, (target_split, template_split) in enumerate(splits)
|
1765
483
|
)
|
1766
|
-
)
|
1767
|
-
|
1768
|
-
matching_data._target, matching_data._template = None, None
|
1769
|
-
matching_data._target_mask, matching_data._template_mask = None, None
|
1770
484
|
|
1771
|
-
|
485
|
+
matching_data._free_data()
|
1772
486
|
if callback_class is not None:
|
1773
|
-
|
1774
|
-
|
1775
|
-
)
|
1776
|
-
|
1777
|
-
return candidates
|
1778
|
-
|
1779
|
-
|
1780
|
-
MATCHING_EXHAUSTIVE_REGISTER = {
|
1781
|
-
"CC": (cc_setup, corr_scoring),
|
1782
|
-
"LCC": (lcc_setup, corr_scoring),
|
1783
|
-
"CORR": (corr_setup, corr_scoring),
|
1784
|
-
"CAM": (cam_setup, corr_scoring),
|
1785
|
-
"FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
|
1786
|
-
"FLC": (flc_setup, flc_scoring),
|
1787
|
-
"FLC2": (flc_setup, flc_scoring2),
|
1788
|
-
"MCC": (mcc_setup, mcc_scoring),
|
1789
|
-
}
|
487
|
+
return callback_class.merge(results, **callback_class_args)
|
488
|
+
return None
|
1790
489
|
|
1791
490
|
|
1792
491
|
def register_matching_exhaustive(
|
@@ -1803,20 +502,17 @@ def register_matching_exhaustive(
|
|
1803
502
|
matching : str
|
1804
503
|
Name of the matching method.
|
1805
504
|
matching_setup : Callable
|
1806
|
-
|
505
|
+
Corresponding setup function.
|
1807
506
|
matching_scoring : Callable
|
1808
|
-
|
507
|
+
Corresponing scoring function.
|
1809
508
|
memory_class : MatchingMemoryUsage
|
1810
|
-
|
1811
|
-
:py:class:`tme.matching_memory.MatchingMemoryUsage`.
|
509
|
+
Child of :py:class:`tme.memory.MatchingMemoryUsage`.
|
1812
510
|
|
1813
511
|
Raises
|
1814
512
|
------
|
1815
513
|
ValueError
|
1816
|
-
If a function with the name ``matching`` already exists in the registry
|
1817
|
-
|
1818
|
-
If ``memory_class`` is not a subclass of
|
1819
|
-
:py:class:`tme.matching_memory.MatchingMemoryUsage`.
|
514
|
+
If a function with the name ``matching`` already exists in the registry, or
|
515
|
+
if ``memory_class`` is no child of :py:class:`tme.memory.MatchingMemoryUsage`.
|
1820
516
|
"""
|
1821
517
|
|
1822
518
|
if matching in MATCHING_EXHAUSTIVE_REGISTER:
|