pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/match_template.py +147 -93
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/postprocess.py +67 -26
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +175 -85
- pytme-0.2.2.dist-info/METADATA +91 -0
- pytme-0.2.2.dist-info/RECORD +74 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
- scripts/extract_candidates.py +20 -13
- scripts/match_template.py +147 -93
- scripts/match_template_filters.py +154 -95
- scripts/postprocess.py +67 -26
- scripts/preprocessor_gui.py +175 -85
- scripts/refine_matches.py +265 -61
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +451 -809
- tme/backends/__init__.py +40 -11
- tme/backends/_jax_utils.py +185 -0
- tme/backends/cupy_backend.py +111 -223
- tme/backends/jax_backend.py +214 -150
- tme/backends/matching_backend.py +445 -384
- tme/backends/mlx_backend.py +32 -59
- tme/backends/npfftw_backend.py +239 -507
- tme/backends/pytorch_backend.py +21 -145
- tme/density.py +233 -363
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +322 -285
- tme/matching_exhaustive.py +172 -1493
- tme/matching_optimization.py +143 -106
- tme/matching_scores.py +884 -0
- tme/matching_utils.py +280 -386
- tme/memory.py +377 -0
- tme/orientations.py +52 -12
- tme/parser.py +3 -4
- tme/preprocessing/_utils.py +61 -32
- tme/preprocessing/compose.py +7 -3
- tme/preprocessing/frequency_filters.py +49 -39
- tme/preprocessing/tilt_series.py +34 -40
- tme/preprocessor.py +560 -526
- tme/structure.py +491 -188
- tme/types.py +5 -3
- pytme-0.2.1.dist-info/METADATA +0 -73
- pytme-0.2.1.dist-info/RECORD +0 -73
- tme/helpers.py +0 -881
- tme/matching_constrained.py +0 -195
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
tme/matching_exhaustive.py
CHANGED
@@ -4,1300 +4,57 @@
|
|
4
4
|
|
5
5
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
6
|
"""
|
7
|
-
import os
|
8
7
|
import sys
|
9
8
|
import warnings
|
10
|
-
|
11
|
-
from typing import Callable, Tuple, Dict
|
9
|
+
import traceback
|
12
10
|
from functools import wraps
|
11
|
+
from itertools import product
|
12
|
+
from typing import Callable, Tuple, Dict, Optional
|
13
|
+
|
13
14
|
from joblib import Parallel, delayed
|
14
15
|
from multiprocessing.managers import SharedMemoryManager
|
15
16
|
|
16
|
-
import
|
17
|
-
from scipy.ndimage import laplace
|
18
|
-
|
19
|
-
from .analyzer import MaxScoreOverRotations
|
20
|
-
from .matching_utils import (
|
21
|
-
handle_traceback,
|
22
|
-
split_numpy_array_slices,
|
23
|
-
conditional_execute,
|
24
|
-
)
|
25
|
-
from .matching_memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
|
17
|
+
from .backends import backend as be
|
26
18
|
from .preprocessing import Compose
|
27
|
-
from .
|
28
|
-
from .
|
29
|
-
from .types import
|
30
|
-
|
31
|
-
os.environ["MKL_NUM_THREADS"] = "1"
|
32
|
-
os.environ["OMP_NUM_THREADS"] = "1"
|
33
|
-
os.environ["PYFFTW_NUM_THREADS"] = "1"
|
34
|
-
os.environ["OPENBLAS_NUM_THREADS"] = "1"
|
35
|
-
|
36
|
-
|
37
|
-
def _run_inner(backend_name, backend_args, **kwargs):
|
38
|
-
from tme.backends import backend
|
19
|
+
from .matching_utils import split_shape
|
20
|
+
from .matching_scores import MATCHING_EXHAUSTIVE_REGISTER
|
21
|
+
from .types import CallbackClass, MatchingData
|
22
|
+
from .memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
|
39
23
|
|
40
|
-
backend.change_backend(backend_name, **backend_args)
|
41
|
-
return scan(**kwargs)
|
42
24
|
|
43
|
-
|
44
|
-
def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
|
25
|
+
def _handle_traceback(last_type, last_value, last_traceback):
|
45
26
|
"""
|
46
|
-
|
47
|
-
standard deviation based on the elements in mask. Subsequently, the template is
|
48
|
-
multiplied by the mask.
|
27
|
+
Handle sys.exc_info().
|
49
28
|
|
50
29
|
Parameters
|
51
30
|
----------
|
52
|
-
|
53
|
-
The
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
Mask intensity used to compute expectations.
|
59
|
-
|
60
|
-
References
|
61
|
-
----------
|
62
|
-
.. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
63
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
64
|
-
|
65
|
-
Returns
|
66
|
-
-------
|
67
|
-
None
|
68
|
-
This function modifies `template` in-place and does not return any value.
|
69
|
-
"""
|
70
|
-
masked_mean = backend.sum(backend.multiply(template, mask))
|
71
|
-
masked_mean = backend.divide(masked_mean, mask_intensity)
|
72
|
-
masked_std = backend.sum(backend.multiply(backend.square(template), mask))
|
73
|
-
masked_std = backend.subtract(
|
74
|
-
masked_std / mask_intensity, backend.square(masked_mean)
|
75
|
-
)
|
76
|
-
masked_std = backend.sqrt(backend.maximum(masked_std, 0))
|
77
|
-
|
78
|
-
backend.subtract(template, masked_mean, out=template)
|
79
|
-
backend.divide(template, masked_std, out=template)
|
80
|
-
backend.multiply(template, mask, out=template)
|
81
|
-
return None
|
82
|
-
|
83
|
-
|
84
|
-
def _normalize_under_mask_overflow_safe(
|
85
|
-
template: NDArray, mask: NDArray, mask_intensity
|
86
|
-
) -> None:
|
87
|
-
_template = backend.astype(template, backend._overflow_safe_dtype)
|
88
|
-
_mask = backend.astype(mask, backend._overflow_safe_dtype)
|
89
|
-
normalize_under_mask(template=_template, mask=_mask, mask_intensity=mask_intensity)
|
90
|
-
template[:] = backend.astype(_template, template.dtype)
|
91
|
-
return None
|
92
|
-
|
93
|
-
|
94
|
-
def apply_filter(ft_template, template_filter):
|
95
|
-
# This is an approximation to applying the mask, irfftn, normalize, rfftn
|
96
|
-
std_before = backend.std(ft_template)
|
97
|
-
backend.multiply(ft_template, template_filter, out=ft_template)
|
98
|
-
backend.multiply(
|
99
|
-
ft_template, std_before / backend.std(ft_template), out=ft_template
|
100
|
-
)
|
101
|
-
|
102
|
-
|
103
|
-
def cc_setup(
|
104
|
-
rfftn: Callable,
|
105
|
-
irfftn: Callable,
|
106
|
-
template: NDArray,
|
107
|
-
target: NDArray,
|
108
|
-
fast_shape: Tuple[int],
|
109
|
-
fast_ft_shape: Tuple[int],
|
110
|
-
shared_memory_handler: Callable,
|
111
|
-
callback_class: Callable,
|
112
|
-
callback_class_args: Dict,
|
113
|
-
**kwargs,
|
114
|
-
) -> Dict:
|
115
|
-
"""
|
116
|
-
Setup to compute the cross-correlation between a target f and template g
|
117
|
-
defined as:
|
118
|
-
|
119
|
-
.. math::
|
120
|
-
|
121
|
-
\\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
122
|
-
|
123
|
-
|
124
|
-
See Also
|
125
|
-
--------
|
126
|
-
:py:meth:`corr_scoring`
|
127
|
-
:py:class:`tme.matching_optimization.CrossCorrelation`
|
128
|
-
"""
|
129
|
-
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
130
|
-
target_pad = backend.topleft_pad(target, fast_shape)
|
131
|
-
target_pad_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
132
|
-
|
133
|
-
rfftn(target_pad, target_pad_ft)
|
134
|
-
|
135
|
-
target_ft_out = backend.arr_to_sharedarr(
|
136
|
-
arr=target_pad_ft, shared_memory_handler=shared_memory_handler
|
137
|
-
)
|
138
|
-
|
139
|
-
template_out = backend.arr_to_sharedarr(
|
140
|
-
arr=template, shared_memory_handler=shared_memory_handler
|
141
|
-
)
|
142
|
-
inv_denominator_buffer = backend.arr_to_sharedarr(
|
143
|
-
arr=backend.preallocate_array(1, real_dtype) + 1,
|
144
|
-
shared_memory_handler=shared_memory_handler,
|
145
|
-
)
|
146
|
-
numerator_buffer = backend.arr_to_sharedarr(
|
147
|
-
arr=backend.preallocate_array(1, real_dtype),
|
148
|
-
shared_memory_handler=shared_memory_handler,
|
149
|
-
)
|
150
|
-
|
151
|
-
target_ft_tuple = (target_ft_out, fast_ft_shape, complex_dtype)
|
152
|
-
template_tuple = (template_out, template.shape, template.dtype)
|
153
|
-
|
154
|
-
inv_denominator_tuple = (inv_denominator_buffer, (1,), real_dtype)
|
155
|
-
numerator_tuple = (numerator_buffer, (1,), real_dtype)
|
156
|
-
|
157
|
-
ret = {
|
158
|
-
"template": template_tuple,
|
159
|
-
"ft_target": target_ft_tuple,
|
160
|
-
"inv_denominator": inv_denominator_tuple,
|
161
|
-
"numerator": numerator_tuple,
|
162
|
-
"targetshape": target.shape,
|
163
|
-
"templateshape": template.shape,
|
164
|
-
"fast_shape": fast_shape,
|
165
|
-
"fast_ft_shape": fast_ft_shape,
|
166
|
-
"callback_class": callback_class,
|
167
|
-
"callback_class_args": callback_class_args,
|
168
|
-
"use_memmap": kwargs.get("use_memmap", False),
|
169
|
-
"temp_dir": kwargs.get("temp_dir", None),
|
170
|
-
}
|
171
|
-
|
172
|
-
return ret
|
173
|
-
|
174
|
-
|
175
|
-
def lcc_setup(**kwargs) -> Dict:
|
176
|
-
"""
|
177
|
-
Setup to compute the cross-correlation between a laplace transformed target f
|
178
|
-
and laplace transformed template g defined as:
|
179
|
-
|
180
|
-
.. math::
|
181
|
-
|
182
|
-
\\mathcal{F}^{-1}(\\mathcal{F}(\\nabla^{2}f) \\cdot \\mathcal{F}(\\nabla^{2} g)^*)
|
183
|
-
|
184
|
-
|
185
|
-
See Also
|
186
|
-
--------
|
187
|
-
:py:meth:`corr_scoring`
|
188
|
-
:py:class:`tme.matching_optimization.LaplaceCrossCorrelation`
|
189
|
-
"""
|
190
|
-
kwargs["target"] = laplace(kwargs["target"], mode="wrap")
|
191
|
-
kwargs["template"] = laplace(kwargs["template"], mode="wrap")
|
192
|
-
return cc_setup(**kwargs)
|
193
|
-
|
194
|
-
|
195
|
-
def corr_setup(
|
196
|
-
rfftn: Callable,
|
197
|
-
irfftn: Callable,
|
198
|
-
template: NDArray,
|
199
|
-
template_mask: NDArray,
|
200
|
-
target: NDArray,
|
201
|
-
fast_shape: Tuple[int],
|
202
|
-
fast_ft_shape: Tuple[int],
|
203
|
-
shared_memory_handler: Callable,
|
204
|
-
callback_class: Callable,
|
205
|
-
callback_class_args: Dict,
|
206
|
-
**kwargs,
|
207
|
-
) -> Dict:
|
208
|
-
"""
|
209
|
-
Setup to compute a normalized cross-correlation between a target f and a template g.
|
210
|
-
|
211
|
-
.. math::
|
212
|
-
|
213
|
-
\\frac{CC(f,g) - \\overline{g} \\cdot CC(f, m)}
|
214
|
-
{(CC(f^2, m) - \\frac{CC(f, m)^2}{N_g}) \\cdot \\sigma_{g}}
|
215
|
-
|
216
|
-
Where:
|
217
|
-
|
218
|
-
.. math::
|
219
|
-
|
220
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
221
|
-
|
222
|
-
and m is a mask with the same dimension as the template filled with ones.
|
223
|
-
|
224
|
-
References
|
225
|
-
----------
|
226
|
-
.. [1] J. P. Lewis, "Fast Normalized Cross-Correlation", Industrial Light
|
227
|
-
and Magic.
|
228
|
-
|
229
|
-
See Also
|
230
|
-
--------
|
231
|
-
:py:meth:`corr_scoring`
|
232
|
-
:py:class:`tme.matching_optimization.NormalizedCrossCorrelation`.
|
233
|
-
"""
|
234
|
-
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
235
|
-
target_pad = backend.topleft_pad(target, fast_shape)
|
236
|
-
|
237
|
-
# The exact composition of the denominator is debatable
|
238
|
-
# scikit-image match_template multiplies the running sum of the target
|
239
|
-
# with a scaling factor derived from the template. This is probably appropriate
|
240
|
-
# in pattern matching situations where the template exists in the target
|
241
|
-
window_template = backend.topleft_pad(template_mask, fast_shape)
|
242
|
-
ft_window_template = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
243
|
-
rfftn(window_template, ft_window_template)
|
244
|
-
window_template = None
|
245
|
-
|
246
|
-
# Target and squared target window sums
|
247
|
-
ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
248
|
-
ft_target2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
249
|
-
denominator = backend.preallocate_array(fast_shape, real_dtype)
|
250
|
-
target_window_sum = backend.preallocate_array(fast_shape, real_dtype)
|
251
|
-
rfftn(target_pad, ft_target)
|
252
|
-
|
253
|
-
rfftn(backend.square(target_pad), ft_target2)
|
254
|
-
backend.multiply(ft_target2, ft_window_template, out=ft_target2)
|
255
|
-
irfftn(ft_target2, denominator)
|
256
|
-
|
257
|
-
backend.multiply(ft_target, ft_window_template, out=ft_window_template)
|
258
|
-
irfftn(ft_window_template, target_window_sum)
|
259
|
-
|
260
|
-
target_pad, ft_target2, ft_window_template = None, None, None
|
261
|
-
|
262
|
-
# Normalizing constants
|
263
|
-
n_observations = backend.sum(template_mask)
|
264
|
-
template_mean = backend.sum(backend.multiply(template, template_mask))
|
265
|
-
template_mean = backend.divide(template_mean, n_observations)
|
266
|
-
template_ssd = backend.sum(
|
267
|
-
backend.square(
|
268
|
-
backend.multiply(backend.subtract(template, template_mean), template_mask)
|
269
|
-
)
|
270
|
-
)
|
271
|
-
template_volume = np.prod(tuple(int(x) for x in template.shape))
|
272
|
-
backend.multiply(template, template_mask, out=template)
|
273
|
-
|
274
|
-
# Final numerator is score - numerator
|
275
|
-
numerator = backend.multiply(target_window_sum, template_mean)
|
276
|
-
|
277
|
-
# Compute denominator
|
278
|
-
backend.multiply(target_window_sum, target_window_sum, out=target_window_sum)
|
279
|
-
backend.divide(target_window_sum, template_volume, out=target_window_sum)
|
280
|
-
|
281
|
-
backend.subtract(denominator, target_window_sum, out=denominator)
|
282
|
-
backend.multiply(denominator, template_ssd, out=denominator)
|
283
|
-
backend.maximum(denominator, 0, out=denominator)
|
284
|
-
backend.sqrt(denominator, out=denominator)
|
285
|
-
target_window_sum = None
|
286
|
-
|
287
|
-
# Invert denominator to compute final score as product
|
288
|
-
denominator_mask = denominator > backend.eps(denominator.dtype)
|
289
|
-
inv_denominator = backend.preallocate_array(fast_shape, real_dtype)
|
290
|
-
inv_denominator[denominator_mask] = 1 / denominator[denominator_mask]
|
291
|
-
|
292
|
-
# Convert arrays used in subsequent fitting to SharedMemory objects
|
293
|
-
template_buffer = backend.arr_to_sharedarr(
|
294
|
-
arr=template, shared_memory_handler=shared_memory_handler
|
295
|
-
)
|
296
|
-
target_ft_buffer = backend.arr_to_sharedarr(
|
297
|
-
arr=ft_target, shared_memory_handler=shared_memory_handler
|
298
|
-
)
|
299
|
-
inv_denominator_buffer = backend.arr_to_sharedarr(
|
300
|
-
arr=inv_denominator, shared_memory_handler=shared_memory_handler
|
301
|
-
)
|
302
|
-
numerator_buffer = backend.arr_to_sharedarr(
|
303
|
-
arr=numerator, shared_memory_handler=shared_memory_handler
|
304
|
-
)
|
305
|
-
|
306
|
-
template_tuple = (template_buffer, template.shape, template.dtype)
|
307
|
-
target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
|
308
|
-
|
309
|
-
inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
|
310
|
-
numerator_tuple = (numerator_buffer, fast_shape, real_dtype)
|
311
|
-
|
312
|
-
ft_target, inv_denominator, numerator = None, None, None
|
313
|
-
|
314
|
-
ret = {
|
315
|
-
"template": template_tuple,
|
316
|
-
"ft_target": target_ft_tuple,
|
317
|
-
"inv_denominator": inv_denominator_tuple,
|
318
|
-
"numerator": numerator_tuple,
|
319
|
-
"targetshape": target.shape,
|
320
|
-
"templateshape": template.shape,
|
321
|
-
"fast_shape": fast_shape,
|
322
|
-
"fast_ft_shape": fast_ft_shape,
|
323
|
-
"callback_class": callback_class,
|
324
|
-
"callback_class_args": callback_class_args,
|
325
|
-
"template_mean": kwargs.get("template_mean", template_mean),
|
326
|
-
}
|
327
|
-
|
328
|
-
return ret
|
329
|
-
|
330
|
-
|
331
|
-
def cam_setup(**kwargs):
|
332
|
-
"""
|
333
|
-
Setup to compute a normalized cross-correlation between a target f and a template g
|
334
|
-
over their means. In practice this can be expressed like the cross-correlation
|
335
|
-
CORR defined in :py:meth:`corr_scoring`, so that:
|
336
|
-
|
337
|
-
Notes
|
338
|
-
-----
|
339
|
-
|
340
|
-
.. math::
|
341
|
-
|
342
|
-
\\text{CORR}(f-\\overline{f}, g - \\overline{g})
|
343
|
-
|
344
|
-
Where
|
345
|
-
|
346
|
-
.. math::
|
347
|
-
|
348
|
-
\\overline{f}, \\overline{g}
|
349
|
-
|
350
|
-
are the mean of f and g respectively.
|
351
|
-
|
352
|
-
References
|
353
|
-
----------
|
354
|
-
.. [1] J. P. Lewis, "Fast Normalized Cross-Correlation", Industrial Light
|
355
|
-
and Magic.
|
356
|
-
|
357
|
-
See Also
|
358
|
-
--------
|
359
|
-
:py:meth:`corr_scoring`.
|
360
|
-
:py:class:`tme.matching_optimization.NormalizedCrossCorrelationMean`.
|
361
|
-
"""
|
362
|
-
kwargs["template"] = kwargs["template"] - kwargs["template"].mean()
|
363
|
-
return corr_setup(**kwargs)
|
364
|
-
|
365
|
-
|
366
|
-
def flc_setup(
|
367
|
-
rfftn: Callable,
|
368
|
-
irfftn: Callable,
|
369
|
-
template: NDArray,
|
370
|
-
template_mask: NDArray,
|
371
|
-
target: NDArray,
|
372
|
-
fast_shape: Tuple[int],
|
373
|
-
fast_ft_shape: Tuple[int],
|
374
|
-
shared_memory_handler: Callable,
|
375
|
-
callback_class: Callable,
|
376
|
-
callback_class_args: Dict,
|
377
|
-
**kwargs,
|
378
|
-
) -> Dict:
|
379
|
-
"""
|
380
|
-
Setup to compute a normalized cross-correlation score of a target f a template g
|
381
|
-
and a mask m:
|
382
|
-
|
383
|
-
.. math::
|
384
|
-
|
385
|
-
\\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
|
386
|
-
{N_m * \\sqrt{
|
387
|
-
\\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
|
388
|
-
}
|
389
|
-
|
390
|
-
Where:
|
391
|
-
|
392
|
-
.. math::
|
393
|
-
|
394
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
31
|
+
last_type : type
|
32
|
+
The type of the last exception.
|
33
|
+
last_value :
|
34
|
+
The value of the last exception.
|
35
|
+
last_traceback : traceback
|
36
|
+
Traceback call stack at the point where the Exception occured.
|
395
37
|
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
.. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
|
401
|
-
Microsc. Microanal. 26, 2516 (2020)
|
402
|
-
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
403
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
404
|
-
|
405
|
-
See Also
|
406
|
-
--------
|
407
|
-
:py:meth:`flc_scoring`
|
408
|
-
"""
|
409
|
-
target_pad = backend.topleft_pad(target, fast_shape)
|
410
|
-
|
411
|
-
# Target and squared target window sums
|
412
|
-
ft_target = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
|
413
|
-
ft_target2 = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
|
414
|
-
rfftn(target_pad, ft_target)
|
415
|
-
backend.square(target_pad, out=target_pad)
|
416
|
-
rfftn(target_pad, ft_target2)
|
417
|
-
|
418
|
-
# Convert arrays used in subsequent fitting to SharedMemory objects
|
419
|
-
ft_target = backend.arr_to_sharedarr(
|
420
|
-
arr=ft_target, shared_memory_handler=shared_memory_handler
|
421
|
-
)
|
422
|
-
ft_target2 = backend.arr_to_sharedarr(
|
423
|
-
arr=ft_target2, shared_memory_handler=shared_memory_handler
|
424
|
-
)
|
425
|
-
|
426
|
-
normalize_under_mask(
|
427
|
-
template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
|
428
|
-
)
|
429
|
-
|
430
|
-
template_buffer = backend.arr_to_sharedarr(
|
431
|
-
arr=template, shared_memory_handler=shared_memory_handler
|
432
|
-
)
|
433
|
-
template_mask_buffer = backend.arr_to_sharedarr(
|
434
|
-
arr=template_mask, shared_memory_handler=shared_memory_handler
|
435
|
-
)
|
436
|
-
|
437
|
-
template_tuple = (template_buffer, template.shape, template.dtype)
|
438
|
-
template_mask_tuple = (
|
439
|
-
template_mask_buffer,
|
440
|
-
template_mask.shape,
|
441
|
-
template_mask.dtype,
|
442
|
-
)
|
443
|
-
|
444
|
-
target_ft_tuple = (ft_target, fast_ft_shape, backend._complex_dtype)
|
445
|
-
target_ft2_tuple = (ft_target2, fast_ft_shape, backend._complex_dtype)
|
446
|
-
|
447
|
-
ret = {
|
448
|
-
"template": template_tuple,
|
449
|
-
"template_mask": template_mask_tuple,
|
450
|
-
"ft_target": target_ft_tuple,
|
451
|
-
"ft_target2": target_ft2_tuple,
|
452
|
-
"targetshape": target.shape,
|
453
|
-
"templateshape": template.shape,
|
454
|
-
"fast_shape": fast_shape,
|
455
|
-
"fast_ft_shape": fast_ft_shape,
|
456
|
-
"callback_class": callback_class,
|
457
|
-
"callback_class_args": callback_class_args,
|
458
|
-
}
|
459
|
-
|
460
|
-
return ret
|
461
|
-
|
462
|
-
|
463
|
-
def flcSphericalMask_setup(
|
464
|
-
rfftn: Callable,
|
465
|
-
irfftn: Callable,
|
466
|
-
template: NDArray,
|
467
|
-
template_mask: NDArray,
|
468
|
-
target: NDArray,
|
469
|
-
fast_shape: Tuple[int],
|
470
|
-
fast_ft_shape: Tuple[int],
|
471
|
-
shared_memory_handler: Callable,
|
472
|
-
callback_class: Callable,
|
473
|
-
callback_class_args: Dict,
|
474
|
-
**kwargs,
|
475
|
-
) -> Dict:
|
476
|
-
"""
|
477
|
-
Like :py:meth:`flc_setup` but for rotation invariant masks. In such cases
|
478
|
-
the score can be computed quicker by not computing the fourier transforms
|
479
|
-
of the mask for each rotation.
|
480
|
-
|
481
|
-
.. math::
|
482
|
-
|
483
|
-
\\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
|
484
|
-
{N_m * \\sqrt{
|
485
|
-
\\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
|
486
|
-
}
|
487
|
-
|
488
|
-
Where:
|
489
|
-
|
490
|
-
.. math::
|
491
|
-
|
492
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
493
|
-
|
494
|
-
and Nm is the number of voxels within the template mask m.
|
495
|
-
|
496
|
-
References
|
497
|
-
----------
|
498
|
-
.. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
|
499
|
-
Microsc. Microanal. 26, 2516 (2020)
|
500
|
-
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
501
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
502
|
-
|
503
|
-
See Also
|
504
|
-
--------
|
505
|
-
:py:meth:`corr_scoring`
|
506
|
-
"""
|
507
|
-
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
508
|
-
target_pad = backend.topleft_pad(target, fast_shape)
|
509
|
-
|
510
|
-
# Target and squared target window sums
|
511
|
-
ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
512
|
-
ft_template_mask = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
513
|
-
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
514
|
-
|
515
|
-
temp = backend.preallocate_array(fast_shape, real_dtype)
|
516
|
-
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
517
|
-
numerator = backend.preallocate_array(1, real_dtype)
|
518
|
-
|
519
|
-
n_observations, norm_func = backend.sum(template_mask), normalize_under_mask
|
520
|
-
if backend.datatype_bytes(template_mask.dtype) == 2:
|
521
|
-
norm_func = _normalize_under_mask_overflow_safe
|
522
|
-
n_observations = backend.sum(
|
523
|
-
backend.astype(template_mask, backend._overflow_safe_dtype)
|
524
|
-
)
|
525
|
-
|
526
|
-
template_mask_pad = backend.topleft_pad(template_mask, fast_shape)
|
527
|
-
rfftn(template_mask_pad, ft_template_mask)
|
528
|
-
|
529
|
-
# Denominator E(X^2) - E(X)^2
|
530
|
-
rfftn(backend.square(target_pad), ft_target)
|
531
|
-
backend.multiply(ft_target, ft_template_mask, out=ft_temp)
|
532
|
-
irfftn(ft_temp, temp2)
|
533
|
-
backend.divide(temp2, n_observations, out=temp2)
|
534
|
-
|
535
|
-
rfftn(target_pad, ft_target)
|
536
|
-
backend.multiply(ft_target, ft_template_mask, out=ft_temp)
|
537
|
-
irfftn(ft_temp, temp)
|
538
|
-
backend.divide(temp, n_observations, out=temp)
|
539
|
-
backend.square(temp, out=temp)
|
540
|
-
|
541
|
-
backend.subtract(temp2, temp, out=temp)
|
542
|
-
|
543
|
-
backend.maximum(temp, 0.0, out=temp)
|
544
|
-
backend.sqrt(temp, out=temp)
|
545
|
-
backend.multiply(temp, n_observations, out=temp)
|
546
|
-
|
547
|
-
backend.fill(temp2, 0)
|
548
|
-
nonzero_indices = temp > backend.eps(real_dtype)
|
549
|
-
temp2[nonzero_indices] = 1 / temp[nonzero_indices]
|
550
|
-
|
551
|
-
norm_func(template=template, mask=template_mask, mask_intensity=n_observations)
|
552
|
-
|
553
|
-
template_buffer = backend.arr_to_sharedarr(
|
554
|
-
arr=template, shared_memory_handler=shared_memory_handler
|
555
|
-
)
|
556
|
-
template_mask_buffer = backend.arr_to_sharedarr(
|
557
|
-
arr=template_mask, shared_memory_handler=shared_memory_handler
|
558
|
-
)
|
559
|
-
target_ft_buffer = backend.arr_to_sharedarr(
|
560
|
-
arr=ft_target, shared_memory_handler=shared_memory_handler
|
561
|
-
)
|
562
|
-
inv_denominator_buffer = backend.arr_to_sharedarr(
|
563
|
-
arr=temp2, shared_memory_handler=shared_memory_handler
|
564
|
-
)
|
565
|
-
numerator_buffer = backend.arr_to_sharedarr(
|
566
|
-
arr=numerator, shared_memory_handler=shared_memory_handler
|
567
|
-
)
|
568
|
-
|
569
|
-
template_tuple = (template_buffer, template.shape, template.dtype)
|
570
|
-
template_mask_tuple = (template_mask_buffer, template.shape, template_mask.dtype)
|
571
|
-
target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
|
572
|
-
|
573
|
-
inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
|
574
|
-
numerator_tuple = (numerator_buffer, (1,), real_dtype)
|
575
|
-
|
576
|
-
ret = {
|
577
|
-
"template": template_tuple,
|
578
|
-
"template_mask": template_mask_tuple,
|
579
|
-
"ft_target": target_ft_tuple,
|
580
|
-
"inv_denominator": inv_denominator_tuple,
|
581
|
-
"numerator": numerator_tuple,
|
582
|
-
"targetshape": target.shape,
|
583
|
-
"templateshape": template.shape,
|
584
|
-
"fast_shape": fast_shape,
|
585
|
-
"fast_ft_shape": fast_ft_shape,
|
586
|
-
"callback_class": callback_class,
|
587
|
-
"callback_class_args": callback_class_args,
|
588
|
-
}
|
589
|
-
|
590
|
-
return ret
|
591
|
-
|
592
|
-
|
593
|
-
def mcc_setup(
|
594
|
-
rfftn: Callable,
|
595
|
-
irfftn: Callable,
|
596
|
-
template: NDArray,
|
597
|
-
template_mask: NDArray,
|
598
|
-
target: NDArray,
|
599
|
-
target_mask: NDArray,
|
600
|
-
fast_shape: Tuple[int],
|
601
|
-
fast_ft_shape: Tuple[int],
|
602
|
-
shared_memory_handler: Callable,
|
603
|
-
callback_class: Callable,
|
604
|
-
callback_class_args: Dict,
|
605
|
-
**kwargs,
|
606
|
-
) -> Dict:
|
607
|
-
"""
|
608
|
-
Setup to compute a normalized cross-correlation score with masks
|
609
|
-
for both template and target.
|
610
|
-
|
611
|
-
.. math::
|
612
|
-
|
613
|
-
\\frac{
|
614
|
-
CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
|
615
|
-
}{
|
616
|
-
\\sqrt{
|
617
|
-
(CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
|
618
|
-
(CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
|
619
|
-
}
|
620
|
-
}
|
621
|
-
|
622
|
-
Where:
|
623
|
-
|
624
|
-
.. math::
|
625
|
-
|
626
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
627
|
-
|
628
|
-
|
629
|
-
References
|
630
|
-
----------
|
631
|
-
.. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
|
632
|
-
|
633
|
-
See Also
|
634
|
-
--------
|
635
|
-
:py:meth:`mcc_scoring`
|
636
|
-
:py:class:`tme.matching_optimization.MaskedCrossCorrelation`
|
637
|
-
"""
|
638
|
-
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
639
|
-
target = backend.multiply(target, target_mask > 0, out=target)
|
640
|
-
|
641
|
-
target_pad = backend.topleft_pad(target, fast_shape)
|
642
|
-
target_mask_pad = backend.topleft_pad(target_mask, fast_shape)
|
643
|
-
|
644
|
-
target_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
645
|
-
rfftn(target_pad, target_ft)
|
646
|
-
target_ft_buffer = backend.arr_to_sharedarr(
|
647
|
-
arr=target_ft, shared_memory_handler=shared_memory_handler
|
648
|
-
)
|
649
|
-
|
650
|
-
target_ft2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
651
|
-
rfftn(backend.square(target_pad), target_ft2)
|
652
|
-
target_ft2_buffer = backend.arr_to_sharedarr(
|
653
|
-
arr=target_ft2, shared_memory_handler=shared_memory_handler
|
654
|
-
)
|
655
|
-
|
656
|
-
target_mask_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
657
|
-
rfftn(target_mask_pad, target_mask_ft)
|
658
|
-
target_mask_ft_buffer = backend.arr_to_sharedarr(
|
659
|
-
arr=target_mask_ft, shared_memory_handler=shared_memory_handler
|
660
|
-
)
|
661
|
-
|
662
|
-
template_buffer = backend.arr_to_sharedarr(
|
663
|
-
arr=template, shared_memory_handler=shared_memory_handler
|
664
|
-
)
|
665
|
-
template_mask_buffer = backend.arr_to_sharedarr(
|
666
|
-
arr=template_mask, shared_memory_handler=shared_memory_handler
|
667
|
-
)
|
668
|
-
|
669
|
-
template_tuple = (template_buffer, template.shape, template.dtype)
|
670
|
-
template_mask_tuple = (template_mask_buffer, template.shape, template_mask.dtype)
|
671
|
-
|
672
|
-
target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
|
673
|
-
target_ft2_tuple = (target_ft2_buffer, fast_ft_shape, complex_dtype)
|
674
|
-
target_mask_ft_tuple = (target_mask_ft_buffer, fast_ft_shape, complex_dtype)
|
675
|
-
|
676
|
-
ret = {
|
677
|
-
"template": template_tuple,
|
678
|
-
"template_mask": template_mask_tuple,
|
679
|
-
"ft_target": target_ft_tuple,
|
680
|
-
"ft_target2": target_ft2_tuple,
|
681
|
-
"ft_target_mask": target_mask_ft_tuple,
|
682
|
-
"targetshape": target.shape,
|
683
|
-
"templateshape": template.shape,
|
684
|
-
"fast_shape": fast_shape,
|
685
|
-
"fast_ft_shape": fast_ft_shape,
|
686
|
-
"callback_class": callback_class,
|
687
|
-
"callback_class_args": callback_class_args,
|
688
|
-
}
|
689
|
-
|
690
|
-
return ret
|
691
|
-
|
692
|
-
|
693
|
-
def corr_scoring(
|
694
|
-
template: Tuple[type, Tuple[int], type],
|
695
|
-
ft_target: Tuple[type, Tuple[int], type],
|
696
|
-
inv_denominator: Tuple[type, Tuple[int], type],
|
697
|
-
numerator: Tuple[type, Tuple[int], type],
|
698
|
-
template_filter: Tuple[type, Tuple[int], type],
|
699
|
-
targetshape: Tuple[int],
|
700
|
-
templateshape: Tuple[int],
|
701
|
-
fast_shape: Tuple[int],
|
702
|
-
fast_ft_shape: Tuple[int],
|
703
|
-
rotations: NDArray,
|
704
|
-
callback_class: CallbackClass,
|
705
|
-
callback_class_args: Dict,
|
706
|
-
interpolation_order: int,
|
707
|
-
convolution_mode: str = "full",
|
708
|
-
**kwargs,
|
709
|
-
) -> CallbackClass:
|
710
|
-
"""
|
711
|
-
Calculates a normalized cross-correlation between a target f and a template g.
|
712
|
-
|
713
|
-
.. math::
|
714
|
-
|
715
|
-
(CC(f,g) - numerator) \\cdot inv\\_denominator
|
716
|
-
|
717
|
-
Parameters
|
718
|
-
----------
|
719
|
-
template : Tuple[type, Tuple[int], type]
|
720
|
-
Tuple containing a pointer to the template data, its shape, and its datatype.
|
721
|
-
ft_target : Tuple[type, Tuple[int], type]
|
722
|
-
Tuple containing a pointer to the fourier tranform of the target,
|
723
|
-
its shape, and its datatype.
|
724
|
-
inv_denominator : Tuple[type, Tuple[int], type]
|
725
|
-
Tuple containing a pointer to the inverse denominator data, its shape, and its
|
726
|
-
datatype.
|
727
|
-
numerator : Tuple[type, Tuple[int], type]
|
728
|
-
Tuple containing a pointer to the numerator data, its shape, and its datatype.
|
729
|
-
fast_shape : Tuple[int]
|
730
|
-
The shape for fast Fourier transform.
|
731
|
-
fast_ft_shape : Tuple[int]
|
732
|
-
The shape for fast Fourier transform of the target.
|
733
|
-
rotations : NDArray
|
734
|
-
Array containing the rotation matrices to be applied on the template.
|
735
|
-
real_dtype : type
|
736
|
-
Data type for the real part of the array.
|
737
|
-
complex_dtype : type
|
738
|
-
Data type for the complex part of the array.
|
739
|
-
callback_class : CallbackClass
|
740
|
-
A callable class or function for processing the results after each
|
741
|
-
rotation.
|
742
|
-
callback_class_args : Dict
|
743
|
-
Dictionary of arguments to be passed to the callback class if it's
|
744
|
-
instantiable.
|
745
|
-
interpolation_order : int
|
746
|
-
The order of interpolation to be used while rotating the template.
|
747
|
-
**kwargs :
|
748
|
-
Additional arguments to be passed to the function.
|
749
|
-
|
750
|
-
Returns
|
751
|
-
-------
|
752
|
-
CallbackClass
|
753
|
-
If callback_class was provided an instance of callback_class otherwise None.
|
754
|
-
|
755
|
-
See Also
|
756
|
-
--------
|
757
|
-
:py:meth:`cc_setup`
|
758
|
-
:py:meth:`corr_setup`
|
759
|
-
:py:meth:`cam_setup`
|
760
|
-
:py:meth:`flcSphericalMask_setup`
|
761
|
-
"""
|
762
|
-
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
763
|
-
|
764
|
-
callback = callback_class
|
765
|
-
if callback_class is not None and isinstance(callback_class, type):
|
766
|
-
callback = callback_class(**callback_class_args)
|
767
|
-
|
768
|
-
template_buffer, template_shape, template_dtype = template
|
769
|
-
template = backend.sharedarr_to_arr(template_buffer, template_shape, template_dtype)
|
770
|
-
ft_target = backend.sharedarr_to_arr(*ft_target)
|
771
|
-
inv_denominator = backend.sharedarr_to_arr(*inv_denominator)
|
772
|
-
numerator = backend.sharedarr_to_arr(*numerator)
|
773
|
-
template_filter = backend.sharedarr_to_arr(*template_filter)
|
774
|
-
|
775
|
-
norm_func = normalize_under_mask
|
776
|
-
norm_template, template_mask, mask_sum = False, 1, 1
|
777
|
-
if "template_mask" in kwargs:
|
778
|
-
norm_template = True
|
779
|
-
template_mask = backend.sharedarr_to_arr(*kwargs["template_mask"])
|
780
|
-
mask_sum = backend.sum(template_mask)
|
781
|
-
if backend.datatype_bytes(template_mask.dtype) == 2:
|
782
|
-
norm_func = _normalize_under_mask_overflow_safe
|
783
|
-
mask_sum = backend.sum(
|
784
|
-
backend.astype(template_mask, backend._overflow_safe_dtype)
|
785
|
-
)
|
786
|
-
|
787
|
-
norm_template = conditional_execute(norm_func, norm_template)
|
788
|
-
norm_numerator = (backend.sum(numerator) != 0) & (backend.size(numerator) != 1)
|
789
|
-
norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
|
790
|
-
|
791
|
-
norm_denominator = (backend.sum(inv_denominator) != 1) & (
|
792
|
-
backend.size(inv_denominator) != 1
|
793
|
-
)
|
794
|
-
norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
|
795
|
-
callback_func = conditional_execute(callback, callback_class is not None)
|
796
|
-
template_filter_func = conditional_execute(
|
797
|
-
apply_filter, backend.size(template_filter) != 1
|
798
|
-
)
|
799
|
-
|
800
|
-
arr = backend.preallocate_array(fast_shape, real_dtype)
|
801
|
-
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
802
|
-
|
803
|
-
rfftn, irfftn = backend.build_fft(
|
804
|
-
fast_shape=fast_shape,
|
805
|
-
fast_ft_shape=fast_ft_shape,
|
806
|
-
real_dtype=real_dtype,
|
807
|
-
complex_dtype=complex_dtype,
|
808
|
-
fftargs=kwargs.get("fftargs", {}),
|
809
|
-
temp_real=arr,
|
810
|
-
temp_fft=ft_temp,
|
811
|
-
)
|
812
|
-
|
813
|
-
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
814
|
-
for index in range(rotations.shape[0]):
|
815
|
-
rotation = rotations[index]
|
816
|
-
backend.fill(arr, 0)
|
817
|
-
backend.rotate_array(
|
818
|
-
arr=template,
|
819
|
-
rotation_matrix=rotation,
|
820
|
-
out=arr,
|
821
|
-
use_geometric_center=True,
|
822
|
-
order=interpolation_order,
|
823
|
-
)
|
824
|
-
norm_template(arr[unpadded_slice], template_mask, mask_sum)
|
825
|
-
|
826
|
-
rfftn(arr, ft_temp)
|
827
|
-
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
828
|
-
backend.multiply(ft_target, ft_temp, out=ft_temp)
|
829
|
-
irfftn(ft_temp, arr)
|
830
|
-
|
831
|
-
norm_func_numerator(arr, numerator, out=arr)
|
832
|
-
norm_func_denominator(arr, inv_denominator, out=arr)
|
833
|
-
|
834
|
-
callback_func(
|
835
|
-
arr,
|
836
|
-
rotation_matrix=rotation,
|
837
|
-
rotation_index=index,
|
838
|
-
**callback_class_args,
|
839
|
-
)
|
840
|
-
|
841
|
-
return callback
|
842
|
-
|
843
|
-
|
844
|
-
def flc_scoring(
|
845
|
-
template: Tuple[type, Tuple[int], type],
|
846
|
-
template_mask: Tuple[type, Tuple[int], type],
|
847
|
-
ft_target: Tuple[type, Tuple[int], type],
|
848
|
-
ft_target2: Tuple[type, Tuple[int], type],
|
849
|
-
template_filter: Tuple[type, Tuple[int], type],
|
850
|
-
targetshape: Tuple[int],
|
851
|
-
templateshape: Tuple[int],
|
852
|
-
fast_shape: Tuple[int],
|
853
|
-
fast_ft_shape: Tuple[int],
|
854
|
-
rotations: NDArray,
|
855
|
-
callback_class: CallbackClass,
|
856
|
-
callback_class_args: Dict,
|
857
|
-
interpolation_order: int,
|
858
|
-
**kwargs,
|
859
|
-
) -> CallbackClass:
|
860
|
-
"""
|
861
|
-
Computes a normalized cross-correlation score of a target f a template g
|
862
|
-
and a mask m:
|
863
|
-
|
864
|
-
.. math::
|
865
|
-
|
866
|
-
\\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
|
867
|
-
{N_m * \\sqrt{
|
868
|
-
\\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
|
869
|
-
}
|
870
|
-
|
871
|
-
Where:
|
872
|
-
|
873
|
-
.. math::
|
874
|
-
|
875
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
876
|
-
|
877
|
-
and Nm is the number of voxels within the template mask m.
|
878
|
-
|
879
|
-
References
|
880
|
-
----------
|
881
|
-
.. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
|
882
|
-
Microsc. Microanal. 26, 2516 (2020)
|
883
|
-
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
884
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
885
|
-
"""
|
886
|
-
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
887
|
-
|
888
|
-
callback = callback_class
|
889
|
-
if callback_class is not None and isinstance(callback_class, type):
|
890
|
-
callback = callback_class(**callback_class_args)
|
891
|
-
|
892
|
-
template = backend.sharedarr_to_arr(*template)
|
893
|
-
template_mask = backend.sharedarr_to_arr(*template_mask)
|
894
|
-
ft_target = backend.sharedarr_to_arr(*ft_target)
|
895
|
-
ft_target2 = backend.sharedarr_to_arr(*ft_target2)
|
896
|
-
template_filter = backend.sharedarr_to_arr(*template_filter)
|
897
|
-
|
898
|
-
arr = backend.preallocate_array(fast_shape, real_dtype)
|
899
|
-
temp = backend.preallocate_array(fast_shape, real_dtype)
|
900
|
-
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
901
|
-
|
902
|
-
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
903
|
-
ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
904
|
-
|
905
|
-
rfftn, irfftn = backend.build_fft(
|
906
|
-
fast_shape=fast_shape,
|
907
|
-
fast_ft_shape=fast_ft_shape,
|
908
|
-
real_dtype=real_dtype,
|
909
|
-
complex_dtype=complex_dtype,
|
910
|
-
fftargs=kwargs.get("fftargs", {}),
|
911
|
-
temp_real=arr,
|
912
|
-
temp_fft=ft_temp,
|
913
|
-
)
|
914
|
-
eps = backend.eps(real_dtype)
|
915
|
-
template_filter_func = conditional_execute(
|
916
|
-
apply_filter, backend.size(template_filter) != 1
|
917
|
-
)
|
918
|
-
callback_func = conditional_execute(callback, callback_class is not None)
|
919
|
-
|
920
|
-
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
921
|
-
for index in range(rotations.shape[0]):
|
922
|
-
rotation = rotations[index]
|
923
|
-
backend.fill(arr, 0)
|
924
|
-
backend.fill(temp, 0)
|
925
|
-
backend.rotate_array(
|
926
|
-
arr=template,
|
927
|
-
arr_mask=template_mask,
|
928
|
-
rotation_matrix=rotation,
|
929
|
-
out=arr,
|
930
|
-
out_mask=temp,
|
931
|
-
use_geometric_center=True,
|
932
|
-
order=interpolation_order,
|
933
|
-
)
|
934
|
-
# Given the amount of FFTs, might aswell normalize properly
|
935
|
-
n_observations = backend.sum(temp)
|
936
|
-
|
937
|
-
normalize_under_mask(
|
938
|
-
template=arr[unpadded_slice],
|
939
|
-
mask=temp[unpadded_slice],
|
940
|
-
mask_intensity=n_observations,
|
941
|
-
)
|
942
|
-
|
943
|
-
rfftn(temp, ft_temp)
|
944
|
-
|
945
|
-
backend.multiply(ft_target, ft_temp, out=ft_denom)
|
946
|
-
irfftn(ft_denom, temp)
|
947
|
-
backend.divide(temp, n_observations, out=temp)
|
948
|
-
backend.square(temp, out=temp)
|
949
|
-
|
950
|
-
backend.multiply(ft_target2, ft_temp, out=ft_denom)
|
951
|
-
irfftn(ft_denom, temp2)
|
952
|
-
backend.divide(temp2, n_observations, out=temp2)
|
953
|
-
|
954
|
-
backend.subtract(temp2, temp, out=temp)
|
955
|
-
backend.maximum(temp, 0.0, out=temp)
|
956
|
-
backend.sqrt(temp, out=temp)
|
957
|
-
backend.multiply(temp, n_observations, out=temp)
|
958
|
-
|
959
|
-
rfftn(arr, ft_temp)
|
960
|
-
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
961
|
-
backend.multiply(ft_target, ft_temp, out=ft_temp)
|
962
|
-
irfftn(ft_temp, arr)
|
963
|
-
|
964
|
-
tol = eps
|
965
|
-
# tol = 1e3 * eps * backend.max(backend.abs(temp))
|
966
|
-
nonzero_indices = temp > tol
|
967
|
-
backend.fill(temp2, 0)
|
968
|
-
temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
|
969
|
-
|
970
|
-
callback_func(
|
971
|
-
temp2,
|
972
|
-
rotation_matrix=rotation,
|
973
|
-
rotation_index=index,
|
974
|
-
**callback_class_args,
|
975
|
-
)
|
976
|
-
|
977
|
-
return callback
|
978
|
-
|
979
|
-
|
980
|
-
def flc_scoring2(
|
981
|
-
template: Tuple[type, Tuple[int], type],
|
982
|
-
template_mask: Tuple[type, Tuple[int], type],
|
983
|
-
ft_target: Tuple[type, Tuple[int], type],
|
984
|
-
ft_target2: Tuple[type, Tuple[int], type],
|
985
|
-
template_filter: Tuple[type, Tuple[int], type],
|
986
|
-
targetshape: Tuple[int],
|
987
|
-
templateshape: Tuple[int],
|
988
|
-
fast_shape: Tuple[int],
|
989
|
-
fast_ft_shape: Tuple[int],
|
990
|
-
rotations: NDArray,
|
991
|
-
callback_class: CallbackClass,
|
992
|
-
callback_class_args: Dict,
|
993
|
-
interpolation_order: int,
|
994
|
-
**kwargs,
|
995
|
-
) -> CallbackClass:
|
996
|
-
"""
|
997
|
-
Computes a normalized cross-correlation score of a target f a template g
|
998
|
-
and a mask m:
|
999
|
-
|
1000
|
-
.. math::
|
1001
|
-
|
1002
|
-
\\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
|
1003
|
-
{N_m * \\sqrt{
|
1004
|
-
\\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
|
1005
|
-
}
|
1006
|
-
|
1007
|
-
Where:
|
1008
|
-
|
1009
|
-
.. math::
|
1010
|
-
|
1011
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
1012
|
-
|
1013
|
-
and Nm is the number of voxels within the template mask m.
|
1014
|
-
|
1015
|
-
References
|
1016
|
-
----------
|
1017
|
-
.. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
|
1018
|
-
Microsc. Microanal. 26, 2516 (2020)
|
1019
|
-
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
1020
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
1021
|
-
"""
|
1022
|
-
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
1023
|
-
callback = callback_class
|
1024
|
-
if callback_class is not None and isinstance(callback_class, type):
|
1025
|
-
callback = callback_class(**callback_class_args)
|
1026
|
-
|
1027
|
-
# Retrieve objects from shared memory
|
1028
|
-
template = backend.sharedarr_to_arr(*template)
|
1029
|
-
template_mask = backend.sharedarr_to_arr(*template_mask)
|
1030
|
-
ft_target = backend.sharedarr_to_arr(*ft_target)
|
1031
|
-
ft_target2 = backend.sharedarr_to_arr(*ft_target2)
|
1032
|
-
template_filter = backend.sharedarr_to_arr(*template_filter)
|
1033
|
-
|
1034
|
-
arr = backend.preallocate_array(fast_shape, real_dtype)
|
1035
|
-
temp = backend.preallocate_array(fast_shape, real_dtype)
|
1036
|
-
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
1037
|
-
|
1038
|
-
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
1039
|
-
ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
1040
|
-
|
1041
|
-
eps = backend.eps(real_dtype)
|
1042
|
-
template_filter_func = conditional_execute(
|
1043
|
-
apply_filter, backend.size(template_filter) != 1
|
1044
|
-
)
|
1045
|
-
callback_func = conditional_execute(callback, callback_class is not None)
|
1046
|
-
|
1047
|
-
squeeze_axis = tuple(i for i, x in enumerate(template.shape) if x == 1)
|
1048
|
-
squeeze = tuple(
|
1049
|
-
slice(0, stop) if i not in squeeze_axis else 0
|
1050
|
-
for i, stop in enumerate(template.shape)
|
1051
|
-
)
|
1052
|
-
squeeze_fast = tuple(
|
1053
|
-
slice(0, stop) if i not in squeeze_axis else 0
|
1054
|
-
for i, stop in enumerate(fast_shape)
|
1055
|
-
)
|
1056
|
-
squeeze_fast_ft = tuple(
|
1057
|
-
slice(0, stop) if i not in squeeze_axis else 0
|
1058
|
-
for i, stop in enumerate(fast_ft_shape)
|
1059
|
-
)
|
1060
|
-
|
1061
|
-
rfftn, irfftn = backend.build_fft(
|
1062
|
-
fast_shape=temp[squeeze_fast].shape,
|
1063
|
-
fast_ft_shape=fast_ft_shape,
|
1064
|
-
real_dtype=real_dtype,
|
1065
|
-
complex_dtype=complex_dtype,
|
1066
|
-
fftargs=kwargs.get("fftargs", {}),
|
1067
|
-
inverse_fast_shape=fast_shape,
|
1068
|
-
temp_real=arr[squeeze_fast],
|
1069
|
-
temp_fft=ft_temp,
|
1070
|
-
)
|
1071
|
-
for index in range(rotations.shape[0]):
|
1072
|
-
rotation = rotations[index]
|
1073
|
-
backend.fill(arr, 0)
|
1074
|
-
backend.fill(temp, 0)
|
1075
|
-
backend.rotate_array(
|
1076
|
-
arr=template[squeeze],
|
1077
|
-
arr_mask=template_mask[squeeze],
|
1078
|
-
rotation_matrix=rotation,
|
1079
|
-
out=arr[squeeze],
|
1080
|
-
out_mask=temp[squeeze],
|
1081
|
-
use_geometric_center=True,
|
1082
|
-
order=interpolation_order,
|
1083
|
-
)
|
1084
|
-
# Given the amount of FFTs, might aswell normalize properly
|
1085
|
-
n_observations = backend.sum(temp)
|
1086
|
-
|
1087
|
-
normalize_under_mask(
|
1088
|
-
template=arr[squeeze],
|
1089
|
-
mask=temp[squeeze],
|
1090
|
-
mask_intensity=n_observations,
|
1091
|
-
)
|
1092
|
-
rfftn(temp[squeeze_fast], ft_temp[squeeze_fast_ft])
|
1093
|
-
|
1094
|
-
backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1095
|
-
irfftn(ft_denom, temp)
|
1096
|
-
backend.divide(temp, n_observations, out=temp)
|
1097
|
-
backend.square(temp, out=temp)
|
1098
|
-
|
1099
|
-
backend.multiply(ft_target2, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1100
|
-
irfftn(ft_denom, temp2)
|
1101
|
-
backend.divide(temp2, n_observations, out=temp2)
|
1102
|
-
|
1103
|
-
backend.subtract(temp2, temp, out=temp)
|
1104
|
-
backend.maximum(temp, 0.0, out=temp)
|
1105
|
-
backend.sqrt(temp, out=temp)
|
1106
|
-
backend.multiply(temp, n_observations, out=temp)
|
1107
|
-
|
1108
|
-
rfftn(arr[squeeze_fast], ft_temp[squeeze_fast_ft])
|
1109
|
-
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
1110
|
-
backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1111
|
-
irfftn(ft_denom, arr)
|
1112
|
-
|
1113
|
-
nonzero_indices = temp > eps
|
1114
|
-
backend.fill(temp2, 0)
|
1115
|
-
temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
|
1116
|
-
|
1117
|
-
callback_func(
|
1118
|
-
temp2,
|
1119
|
-
rotation_matrix=rotation,
|
1120
|
-
rotation_index=index,
|
1121
|
-
**callback_class_args,
|
1122
|
-
)
|
1123
|
-
return callback
|
1124
|
-
|
1125
|
-
|
1126
|
-
def mcc_scoring(
|
1127
|
-
template: Tuple[type, Tuple[int], type],
|
1128
|
-
template_mask: Tuple[type, Tuple[int], type],
|
1129
|
-
ft_target: Tuple[type, Tuple[int], type],
|
1130
|
-
ft_target2: Tuple[type, Tuple[int], type],
|
1131
|
-
ft_target_mask: Tuple[type, Tuple[int], type],
|
1132
|
-
template_filter: Tuple[type, Tuple[int], type],
|
1133
|
-
targetshape: Tuple[int],
|
1134
|
-
templateshape: Tuple[int],
|
1135
|
-
fast_shape: Tuple[int],
|
1136
|
-
fast_ft_shape: Tuple[int],
|
1137
|
-
rotations: NDArray,
|
1138
|
-
callback_class: CallbackClass,
|
1139
|
-
callback_class_args: type,
|
1140
|
-
interpolation_order: int,
|
1141
|
-
overlap_ratio: float = 0.3,
|
1142
|
-
**kwargs,
|
1143
|
-
) -> CallbackClass:
|
1144
|
-
"""
|
1145
|
-
Computes a cross-correlation score with masks for both template and target.
|
1146
|
-
|
1147
|
-
.. math::
|
1148
|
-
|
1149
|
-
\\frac{
|
1150
|
-
CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
|
1151
|
-
}{
|
1152
|
-
\\sqrt{
|
1153
|
-
(CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
|
1154
|
-
(CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
|
1155
|
-
}
|
1156
|
-
}
|
1157
|
-
|
1158
|
-
Where:
|
1159
|
-
|
1160
|
-
.. math::
|
1161
|
-
|
1162
|
-
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
1163
|
-
|
1164
|
-
|
1165
|
-
References
|
1166
|
-
----------
|
1167
|
-
.. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
|
1168
|
-
.. [2] https://scikit-image.org/docs/stable/api/skimage.registration.html
|
1169
|
-
|
1170
|
-
See Also
|
1171
|
-
--------
|
1172
|
-
:py:class:`tme.matching_optimization.MaskedCrossCorrelation`
|
38
|
+
Raises
|
39
|
+
------
|
40
|
+
Exception
|
41
|
+
Re-raises the last exception.
|
1173
42
|
"""
|
1174
|
-
|
1175
|
-
|
1176
|
-
|
1177
|
-
|
1178
|
-
|
1179
|
-
# Retrieve objects from shared memory
|
1180
|
-
template_buffer, template_shape, template_dtype = template
|
1181
|
-
template = backend.sharedarr_to_arr(*template)
|
1182
|
-
target_ft = backend.sharedarr_to_arr(*ft_target)
|
1183
|
-
target_ft2 = backend.sharedarr_to_arr(*ft_target2)
|
1184
|
-
template_mask = backend.sharedarr_to_arr(*template_mask)
|
1185
|
-
target_mask_ft = backend.sharedarr_to_arr(*ft_target_mask)
|
1186
|
-
template_filter = backend.sharedarr_to_arr(*template_filter)
|
1187
|
-
|
1188
|
-
axes = tuple(range(template.ndim))
|
1189
|
-
eps = backend.eps(real_dtype)
|
1190
|
-
|
1191
|
-
# Allocate score and process specific arrays
|
1192
|
-
template_rot = backend.preallocate_array(fast_shape, real_dtype)
|
1193
|
-
mask_overlap = backend.preallocate_array(fast_shape, real_dtype)
|
1194
|
-
numerator = backend.preallocate_array(fast_shape, real_dtype)
|
1195
|
-
|
1196
|
-
temp = backend.preallocate_array(fast_shape, real_dtype)
|
1197
|
-
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
1198
|
-
temp3 = backend.preallocate_array(fast_shape, real_dtype)
|
1199
|
-
temp_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
1200
|
-
|
1201
|
-
rfftn, irfftn = backend.build_fft(
|
1202
|
-
fast_shape=fast_shape,
|
1203
|
-
fast_ft_shape=fast_ft_shape,
|
1204
|
-
real_dtype=real_dtype,
|
1205
|
-
complex_dtype=complex_dtype,
|
1206
|
-
fftargs=kwargs.get("fftargs", {}),
|
1207
|
-
temp_real=numerator,
|
1208
|
-
temp_fft=temp_ft,
|
1209
|
-
)
|
1210
|
-
|
1211
|
-
template_filter_func = conditional_execute(
|
1212
|
-
apply_filter, backend.size(template_filter) != 1
|
1213
|
-
)
|
1214
|
-
callback_func = conditional_execute(callback, callback_class is not None)
|
1215
|
-
|
1216
|
-
# Calculate scores across all rotations
|
1217
|
-
for index in range(rotations.shape[0]):
|
1218
|
-
rotation = rotations[index]
|
1219
|
-
backend.fill(template_rot, 0)
|
1220
|
-
backend.fill(temp, 0)
|
1221
|
-
|
1222
|
-
backend.rotate_array(
|
1223
|
-
arr=template,
|
1224
|
-
arr_mask=template_mask,
|
1225
|
-
rotation_matrix=rotation,
|
1226
|
-
out=template_rot,
|
1227
|
-
out_mask=temp,
|
1228
|
-
use_geometric_center=True,
|
1229
|
-
order=interpolation_order,
|
1230
|
-
)
|
1231
|
-
|
1232
|
-
backend.multiply(template_rot, temp > 0, out=template_rot)
|
1233
|
-
|
1234
|
-
# template_rot_ft
|
1235
|
-
rfftn(template_rot, temp_ft)
|
1236
|
-
template_filter_func(ft_template=temp_ft, template_filter=template_filter)
|
1237
|
-
irfftn(target_mask_ft * temp_ft, temp2)
|
1238
|
-
irfftn(target_ft * temp_ft, numerator)
|
1239
|
-
|
1240
|
-
# temp template_mask_rot | temp_ft template_mask_rot_ft
|
1241
|
-
# Calculate overlap of masks at every point in the convolution.
|
1242
|
-
# Locations with high overlap should not be taken into account.
|
1243
|
-
rfftn(temp, temp_ft)
|
1244
|
-
irfftn(temp_ft * target_mask_ft, mask_overlap)
|
1245
|
-
mask_overlap[:] = np.round(mask_overlap)
|
1246
|
-
mask_overlap[:] = np.maximum(mask_overlap, eps)
|
1247
|
-
irfftn(temp_ft * target_ft, temp)
|
1248
|
-
|
1249
|
-
backend.subtract(
|
1250
|
-
numerator,
|
1251
|
-
backend.divide(backend.multiply(temp, temp2), mask_overlap),
|
1252
|
-
out=numerator,
|
1253
|
-
)
|
1254
|
-
|
1255
|
-
# temp_3 = fixed_denom
|
1256
|
-
backend.multiply(temp_ft, target_ft2, out=temp_ft)
|
1257
|
-
irfftn(temp_ft, temp3)
|
1258
|
-
backend.subtract(
|
1259
|
-
temp3, backend.divide(backend.square(temp), mask_overlap), out=temp3
|
1260
|
-
)
|
1261
|
-
backend.maximum(temp3, 0.0, out=temp3)
|
43
|
+
if last_type is None:
|
44
|
+
return None
|
45
|
+
traceback.print_tb(last_traceback)
|
46
|
+
raise Exception(last_value)
|
1262
47
|
|
1263
|
-
# temp = moving_denom
|
1264
|
-
rfftn(backend.square(template_rot), temp_ft)
|
1265
|
-
backend.multiply(target_mask_ft, temp_ft, out=temp_ft)
|
1266
|
-
irfftn(temp_ft, temp)
|
1267
48
|
|
1268
|
-
|
1269
|
-
|
1270
|
-
|
1271
|
-
|
1272
|
-
|
1273
|
-
# temp_2 = denom
|
1274
|
-
backend.multiply(temp3, temp, out=temp)
|
1275
|
-
backend.sqrt(temp, temp2)
|
1276
|
-
|
1277
|
-
# Pixels where `denom` is very small will introduce large
|
1278
|
-
# numbers after division. To get around this problem,
|
1279
|
-
# we zero-out problematic pixels.
|
1280
|
-
tol = 1e3 * eps * backend.max(backend.abs(temp2), axis=axes, keepdims=True)
|
1281
|
-
nonzero_indices = temp2 > tol
|
1282
|
-
|
1283
|
-
backend.fill(temp, 0)
|
1284
|
-
temp[nonzero_indices] = numerator[nonzero_indices] / temp2[nonzero_indices]
|
1285
|
-
backend.clip(temp, a_min=-1, a_max=1, out=temp)
|
1286
|
-
|
1287
|
-
# Apply overlap ratio threshold
|
1288
|
-
number_px_threshold = overlap_ratio * backend.max(
|
1289
|
-
mask_overlap, axis=axes, keepdims=True
|
1290
|
-
)
|
1291
|
-
temp[mask_overlap < number_px_threshold] = 0.0
|
49
|
+
def _wrap_backend(func):
|
50
|
+
@wraps(func)
|
51
|
+
def wrapper(*args, backend_name: str, backend_args: Dict, **kwargs):
|
52
|
+
from tme.backends import backend as be
|
1292
53
|
|
1293
|
-
|
1294
|
-
|
1295
|
-
rotation_matrix=rotation,
|
1296
|
-
rotation_index=index,
|
1297
|
-
**callback_class_args,
|
1298
|
-
)
|
54
|
+
be.change_backend(backend_name, **backend_args)
|
55
|
+
return func(*args, **kwargs)
|
1299
56
|
|
1300
|
-
return
|
57
|
+
return wrapper
|
1301
58
|
|
1302
59
|
|
1303
60
|
def _setup_template_filter_apply_target_filter(
|
@@ -1306,42 +63,47 @@ def _setup_template_filter_apply_target_filter(
|
|
1306
63
|
irfftn: Callable,
|
1307
64
|
fast_shape: Tuple[int],
|
1308
65
|
fast_ft_shape: Tuple[int],
|
66
|
+
pad_template_filter: bool = True,
|
1309
67
|
):
|
1310
68
|
filter_template = isinstance(matching_data.template_filter, Compose)
|
1311
69
|
filter_target = isinstance(matching_data.target_filter, Compose)
|
1312
70
|
|
1313
|
-
template_filter =
|
71
|
+
template_filter = be.full(shape=(1,), fill_value=1, dtype=be._float_dtype)
|
1314
72
|
|
1315
73
|
if not filter_template and not filter_target:
|
1316
74
|
return template_filter
|
1317
75
|
|
1318
|
-
target_temp =
|
1319
|
-
|
1320
|
-
)
|
1321
|
-
target_temp_ft = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
|
76
|
+
target_temp = be.topleft_pad(matching_data.target, fast_shape)
|
77
|
+
target_temp_ft = be.zeros(fast_ft_shape, be._complex_dtype)
|
1322
78
|
|
1323
|
-
|
1324
|
-
filter_shape
|
1325
|
-
|
1326
|
-
fast_shape = fast_shape[fast_shape != 0]
|
79
|
+
inv_mask = be.subtract(1, be.to_backend_array(matching_data._batch_mask))
|
80
|
+
filter_shape = be.multiply(be.to_backend_array(fast_ft_shape), inv_mask)
|
81
|
+
filter_shape = tuple(int(x) if x != 0 else 1 for x in filter_shape)
|
1327
82
|
|
1328
|
-
fast_shape =
|
1329
|
-
|
83
|
+
fast_shape = be.multiply(be.to_backend_array(fast_shape), inv_mask)
|
84
|
+
fast_shape = tuple(int(x) for x in fast_shape if x != 0)
|
85
|
+
|
86
|
+
target_temp_ft = rfftn(target_temp, target_temp_ft)
|
87
|
+
if filter_template:
|
88
|
+
# TODO: Pad to fast shapes and adapt _setup_template_filtering accordingly
|
89
|
+
template_fast_shape, template_filter_shape = fast_shape, filter_shape
|
90
|
+
if not pad_template_filter:
|
91
|
+
template_fast_shape = tuple(int(x) for x in matching_data._template.shape)
|
92
|
+
template_filter_shape = [
|
93
|
+
int(x) for x in matching_data._output_template_shape
|
94
|
+
]
|
95
|
+
template_filter_shape[-1] = template_filter_shape[-1] // 2 + 1
|
1330
96
|
|
1331
|
-
rfftn(target_temp, target_temp_ft)
|
1332
|
-
if isinstance(matching_data.template_filter, Compose):
|
1333
97
|
template_filter = matching_data.template_filter(
|
1334
|
-
shape=
|
98
|
+
shape=template_fast_shape,
|
1335
99
|
return_real_fourier=True,
|
1336
100
|
shape_is_real_fourier=False,
|
1337
101
|
data_rfft=target_temp_ft,
|
1338
102
|
batch_dimension=matching_data._target_dims,
|
1339
|
-
)
|
1340
|
-
template_filter = template_filter
|
1341
|
-
template_filter[tuple(0 for _ in range(template_filter.ndim))] = 0
|
1342
|
-
template_filter = backend.reshape(template_filter, filter_shape)
|
103
|
+
)["data"]
|
104
|
+
template_filter = be.reshape(template_filter, template_filter_shape)
|
1343
105
|
|
1344
|
-
if
|
106
|
+
if filter_target:
|
1345
107
|
target_filter = matching_data.target_filter(
|
1346
108
|
shape=fast_shape,
|
1347
109
|
return_real_fourier=True,
|
@@ -1349,15 +111,12 @@ def _setup_template_filter_apply_target_filter(
|
|
1349
111
|
data_rfft=target_temp_ft,
|
1350
112
|
weight_type=None,
|
1351
113
|
batch_dimension=matching_data._target_dims,
|
1352
|
-
)
|
1353
|
-
target_filter = target_filter
|
1354
|
-
|
1355
|
-
backend.multiply(target_temp_ft, target_filter, out=target_temp_ft)
|
114
|
+
)["data"]
|
115
|
+
target_filter = be.reshape(target_filter, filter_shape)
|
116
|
+
target_temp_ft = be.multiply(target_temp_ft, target_filter, out=target_temp_ft)
|
1356
117
|
|
1357
|
-
irfftn(target_temp_ft, target_temp)
|
1358
|
-
matching_data._target =
|
1359
|
-
target_temp, matching_data.target.shape
|
1360
|
-
)
|
118
|
+
target_temp = irfftn(target_temp_ft, target_temp)
|
119
|
+
matching_data._target = be.topleft_pad(target_temp, matching_data.target.shape)
|
1361
120
|
|
1362
121
|
return template_filter
|
1363
122
|
|
@@ -1371,13 +130,14 @@ def device_memory_handler(func: Callable):
|
|
1371
130
|
last_type, last_value, last_traceback = sys.exc_info()
|
1372
131
|
try:
|
1373
132
|
with SharedMemoryManager() as smh:
|
1374
|
-
|
133
|
+
gpu_index = kwargs.pop("gpu_index") if "gpu_index" in kwargs else 0
|
134
|
+
with be.set_device(gpu_index):
|
1375
135
|
return_value = func(shared_memory_handler=smh, *args, **kwargs)
|
1376
136
|
except Exception as e:
|
1377
137
|
print(e)
|
1378
138
|
last_type, last_value, last_traceback = sys.exc_info()
|
1379
139
|
finally:
|
1380
|
-
|
140
|
+
_handle_traceback(last_type, last_value, last_traceback)
|
1381
141
|
return return_value
|
1382
142
|
|
1383
143
|
return inner_function
|
@@ -1391,18 +151,20 @@ def scan(
|
|
1391
151
|
n_jobs: int = 4,
|
1392
152
|
callback_class: CallbackClass = None,
|
1393
153
|
callback_class_args: Dict = {},
|
1394
|
-
fftargs: Dict = {},
|
1395
154
|
pad_fourier: bool = True,
|
155
|
+
pad_template_filter: bool = True,
|
1396
156
|
interpolation_order: int = 3,
|
1397
157
|
jobs_per_callback_class: int = 8,
|
1398
|
-
|
1399
|
-
) -> Tuple:
|
158
|
+
shared_memory_handler=None,
|
159
|
+
) -> Optional[Tuple]:
|
1400
160
|
"""
|
1401
|
-
|
161
|
+
Run template matching.
|
162
|
+
|
163
|
+
.. warning:: ``matching_data`` might be altered or destroyed during computation.
|
1402
164
|
|
1403
165
|
Parameters
|
1404
166
|
----------
|
1405
|
-
matching_data : MatchingData
|
167
|
+
matching_data : :py:class:`tme.matching_data.MatchingData`
|
1406
168
|
Template matching data.
|
1407
169
|
matching_setup : Callable
|
1408
170
|
Function pointer to setup function.
|
@@ -1414,21 +176,21 @@ def scan(
|
|
1414
176
|
Analyzer class pointer to operate on computed scores.
|
1415
177
|
callback_class_args : dict, optional
|
1416
178
|
Arguments passed to the callback_class. Default is an empty dictionary.
|
1417
|
-
fftargs : dict, optional
|
1418
|
-
Arguments for the FFT operations. Default is an empty dictionary.
|
1419
179
|
pad_fourier: bool, optional
|
1420
180
|
Whether to pad target and template to the full convolution shape.
|
181
|
+
pad_template_filter: bool, optional
|
182
|
+
Whether to pad potential template filters to the full convolution shape.
|
1421
183
|
interpolation_order : int, optional
|
1422
184
|
Order of spline interpolation for rotations.
|
1423
185
|
jobs_per_callback_class : int, optional
|
1424
186
|
How many jobs should be processed by a single callback_class instance,
|
1425
187
|
if one is provided.
|
1426
|
-
|
1427
|
-
|
188
|
+
shared_memory_handler : type, optional
|
189
|
+
Manager for shared memory objects, None by default.
|
1428
190
|
|
1429
191
|
Returns
|
1430
192
|
-------
|
1431
|
-
Tuple
|
193
|
+
Optional[Tuple]
|
1432
194
|
The merged results from callback_class if provided otherwise None.
|
1433
195
|
|
1434
196
|
Examples
|
@@ -1450,159 +212,95 @@ def scan(
|
|
1450
212
|
|
1451
213
|
"""
|
1452
214
|
matching_data.to_backend()
|
1453
|
-
shape_diff = backend.subtract(
|
1454
|
-
matching_data._output_target_shape, matching_data._output_template_shape
|
1455
|
-
)
|
1456
|
-
shape_diff = backend.multiply(shape_diff, 1 - matching_data._batch_mask)
|
1457
|
-
|
1458
|
-
if backend.sum(shape_diff < 0) and not pad_fourier:
|
1459
|
-
warnings.warn(
|
1460
|
-
"Target is larger than template and Fourier padding is turned off."
|
1461
|
-
" This can lead to shifted results. You can swap template and target,"
|
1462
|
-
" zero-pad the target or turn off template centering."
|
1463
|
-
)
|
1464
215
|
fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
|
1465
216
|
pad_fourier=pad_fourier
|
1466
217
|
)
|
1467
218
|
|
1468
|
-
|
1469
|
-
rfftn, irfftn = backend.build_fft(
|
219
|
+
rfftn, irfftn = be.build_fft(
|
1470
220
|
fast_shape=fast_shape,
|
1471
221
|
fast_ft_shape=fast_ft_shape,
|
1472
|
-
real_dtype=
|
1473
|
-
complex_dtype=
|
1474
|
-
fftargs=fftargs,
|
222
|
+
real_dtype=be._float_dtype,
|
223
|
+
complex_dtype=be._complex_dtype,
|
1475
224
|
)
|
1476
|
-
|
1477
225
|
template_filter = _setup_template_filter_apply_target_filter(
|
1478
226
|
matching_data=matching_data,
|
1479
227
|
rfftn=rfftn,
|
1480
228
|
irfftn=irfftn,
|
1481
229
|
fast_shape=fast_shape,
|
1482
230
|
fast_ft_shape=fast_ft_shape,
|
231
|
+
pad_template_filter=pad_template_filter,
|
1483
232
|
)
|
233
|
+
template_filter = be.astype(be.to_backend_array(template_filter), be._float_dtype)
|
1484
234
|
|
1485
235
|
setup = matching_setup(
|
1486
236
|
rfftn=rfftn,
|
1487
237
|
irfftn=irfftn,
|
1488
238
|
template=matching_data.template,
|
239
|
+
template_filter=template_filter,
|
1489
240
|
template_mask=matching_data.template_mask,
|
1490
241
|
target=matching_data.target,
|
1491
242
|
target_mask=matching_data.target_mask,
|
1492
243
|
fast_shape=fast_shape,
|
1493
244
|
fast_ft_shape=fast_ft_shape,
|
1494
|
-
|
1495
|
-
callback_class_args=callback_class_args,
|
1496
|
-
**kwargs,
|
245
|
+
shared_memory_handler=shared_memory_handler,
|
1497
246
|
)
|
1498
247
|
rfftn, irfftn = None, None
|
1499
|
-
|
1500
|
-
template_filter = backend.to_backend_array(template_filter)
|
1501
|
-
template_filter = backend.astype(template_filter, backend._float_dtype)
|
1502
|
-
template_filter_buffer = backend.arr_to_sharedarr(
|
1503
|
-
arr=template_filter,
|
1504
|
-
shared_memory_handler=kwargs.get("shared_memory_handler", None),
|
1505
|
-
)
|
1506
|
-
setup["template_filter"] = (
|
1507
|
-
template_filter_buffer,
|
1508
|
-
template_filter.shape,
|
1509
|
-
template_filter.dtype,
|
1510
|
-
)
|
1511
|
-
|
1512
|
-
callback_class_args["translation_offset"] = backend.astype(
|
1513
|
-
matching_data._translation_offset, int
|
1514
|
-
)
|
1515
|
-
callback_class_args["thread_safe"] = n_jobs > 1
|
1516
|
-
callback_class_args["gpu_index"] = kwargs.get("gpu_index", -1)
|
1517
|
-
|
1518
|
-
n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
|
1519
|
-
callback_class = setup.pop("callback_class", callback_class)
|
1520
|
-
callback_class_args = setup.pop("callback_class_args", callback_class_args)
|
1521
|
-
callback_classes = [callback_class for _ in range(n_callback_classes)]
|
1522
|
-
|
1523
|
-
convolution_mode = "same"
|
1524
|
-
if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
|
1525
|
-
convolution_mode = "valid"
|
1526
|
-
|
1527
|
-
callback_class_args["fourier_shift"] = fourier_shift
|
1528
|
-
callback_class_args["convolution_mode"] = convolution_mode
|
1529
|
-
callback_class_args["targetshape"] = setup["targetshape"]
|
1530
|
-
callback_class_args["templateshape"] = setup["templateshape"]
|
1531
|
-
|
1532
|
-
if callback_class == MaxScoreOverRotations:
|
1533
|
-
callback_classes = [
|
1534
|
-
class_name(
|
1535
|
-
score_space_shape=fast_shape,
|
1536
|
-
score_space_dtype=backend._float_dtype,
|
1537
|
-
shared_memory_handler=kwargs.get("shared_memory_handler", None),
|
1538
|
-
rotation_space_dtype=backend._int_dtype,
|
1539
|
-
**callback_class_args,
|
1540
|
-
)
|
1541
|
-
for class_name in callback_classes
|
1542
|
-
]
|
1543
|
-
|
1544
|
-
matching_data._target, matching_data._template = None, None
|
1545
|
-
matching_data._target_mask, matching_data._template_mask = None, None
|
1546
|
-
|
1547
|
-
setup["fftargs"] = fftargs.copy()
|
1548
|
-
convolution_mode = "same"
|
1549
|
-
if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
|
1550
|
-
convolution_mode = "valid"
|
1551
|
-
setup["convolution_mode"] = convolution_mode
|
1552
248
|
setup["interpolation_order"] = interpolation_order
|
1553
|
-
|
1554
|
-
|
1555
|
-
|
1556
|
-
|
1557
|
-
|
1558
|
-
|
249
|
+
setup["template_filter"] = be.to_sharedarr(template_filter, shared_memory_handler)
|
250
|
+
|
251
|
+
offset = be.to_backend_array(matching_data._translation_offset)
|
252
|
+
convmode = "valid" if getattr(matching_data, "_is_padded", False) else "same"
|
253
|
+
default_callback_args = {
|
254
|
+
"offset": be.astype(offset, be._int_dtype),
|
255
|
+
"thread_safe": n_jobs > 1,
|
256
|
+
"fourier_shift": fourier_shift,
|
257
|
+
"convolution_mode": convmode,
|
258
|
+
"targetshape": matching_data.target.shape,
|
259
|
+
"templateshape": matching_data.template.shape,
|
260
|
+
"fast_shape": fast_shape,
|
261
|
+
"indices": getattr(matching_data, "indices", None),
|
262
|
+
"shared_memory_handler": shared_memory_handler,
|
263
|
+
"only_unique_rotations": True,
|
264
|
+
}
|
265
|
+
default_callback_args.update(callback_class_args)
|
1559
266
|
|
1560
|
-
|
1561
|
-
|
267
|
+
matching_data._free_data()
|
268
|
+
be.free_cache()
|
1562
269
|
|
270
|
+
# For highly parallel jobs, blocking in certain analyzers becomes a bottleneck
|
271
|
+
if getattr(callback_class, "shared", True):
|
272
|
+
jobs_per_callback_class = 1
|
273
|
+
n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
|
274
|
+
callback_classes = [
|
275
|
+
callback_class(
|
276
|
+
shape=fast_shape,
|
277
|
+
**default_callback_args,
|
278
|
+
)
|
279
|
+
if callback_class is not None
|
280
|
+
else None
|
281
|
+
for _ in range(n_callback_classes)
|
282
|
+
]
|
1563
283
|
callbacks = Parallel(n_jobs=n_jobs)(
|
1564
|
-
delayed(
|
1565
|
-
backend_name=
|
1566
|
-
backend_args=
|
284
|
+
delayed(_wrap_backend(matching_score))(
|
285
|
+
backend_name=be._backend_name,
|
286
|
+
backend_args=be._backend_args,
|
1567
287
|
rotations=rotation,
|
1568
|
-
|
1569
|
-
callback_class_args=callback_class_args,
|
288
|
+
callback=callback_classes[index % n_callback_classes],
|
1570
289
|
**setup,
|
1571
290
|
)
|
1572
|
-
for index, rotation in enumerate(
|
291
|
+
for index, rotation in enumerate(matching_data._split_rotations_on_jobs(n_jobs))
|
1573
292
|
)
|
1574
293
|
|
1575
|
-
callbacks = callbacks[0:n_callback_classes]
|
1576
294
|
callbacks = [
|
1577
|
-
tuple(
|
1578
|
-
callback._postprocess(
|
1579
|
-
fourier_shift=fourier_shift,
|
1580
|
-
convolution_mode=convolution_mode,
|
1581
|
-
targetshape=setup["targetshape"],
|
1582
|
-
templateshape=setup["templateshape"],
|
1583
|
-
shared_memory_handler=kwargs.get("shared_memory_handler", None),
|
1584
|
-
)
|
1585
|
-
)
|
1586
|
-
if hasattr(callback, "_postprocess")
|
1587
|
-
else tuple(callback)
|
295
|
+
tuple(callback._postprocess(**default_callback_args))
|
1588
296
|
for callback in callbacks
|
1589
297
|
if callback is not None
|
1590
298
|
]
|
1591
|
-
|
299
|
+
be.free_cache()
|
1592
300
|
|
1593
|
-
merged_callback = None
|
1594
301
|
if callback_class is not None:
|
1595
|
-
|
1596
|
-
|
1597
|
-
score_indices = matching_data.indices
|
1598
|
-
merged_callback = callback_class.merge(
|
1599
|
-
callbacks,
|
1600
|
-
**callback_class_args,
|
1601
|
-
score_indices=score_indices,
|
1602
|
-
inner_merge=True,
|
1603
|
-
)
|
1604
|
-
|
1605
|
-
return merged_callback
|
302
|
+
return callback_class.merge(callbacks, **default_callback_args)
|
303
|
+
return None
|
1606
304
|
|
1607
305
|
|
1608
306
|
def scan_subsets(
|
@@ -1616,10 +314,12 @@ def scan_subsets(
|
|
1616
314
|
template_splits: Dict = {},
|
1617
315
|
pad_target_edges: bool = False,
|
1618
316
|
pad_fourier: bool = True,
|
317
|
+
pad_template_filter: bool = True,
|
1619
318
|
interpolation_order: int = 3,
|
1620
319
|
jobs_per_callback_class: int = 8,
|
1621
|
-
|
1622
|
-
|
320
|
+
backend_name: str = None,
|
321
|
+
backend_args: Dict = {},
|
322
|
+
) -> Optional[Tuple]:
|
1623
323
|
"""
|
1624
324
|
Wrapper around :py:meth:`scan` that supports matching on splits
|
1625
325
|
of ``matching_data``.
|
@@ -1651,21 +351,17 @@ def scan_subsets(
|
|
1651
351
|
along each axis.
|
1652
352
|
pad_fourier: bool, optional
|
1653
353
|
Whether to pad target and template to the full convolution shape.
|
354
|
+
pad_template_filter: bool, optional
|
355
|
+
Whether to pad potential template filters to the full convolution shape.
|
1654
356
|
interpolation_order : int, optional
|
1655
357
|
Order of spline interpolation for rotations.
|
1656
358
|
jobs_per_callback_class : int, optional
|
1657
359
|
How many jobs should be processed by a single callback_class instance,
|
1658
360
|
if ones is provided.
|
1659
|
-
**kwargs : various
|
1660
|
-
Additional arguments.
|
1661
|
-
|
1662
|
-
Notes
|
1663
|
-
-----
|
1664
|
-
Objects in matching_data might be destroyed during computation.
|
1665
361
|
|
1666
362
|
Returns
|
1667
363
|
-------
|
1668
|
-
Tuple
|
364
|
+
Optional[Tuple]
|
1669
365
|
The merged results from callback_class if provided otherwise None.
|
1670
366
|
|
1671
367
|
Examples
|
@@ -1720,73 +416,59 @@ def scan_subsets(
|
|
1720
416
|
>>> target_splits = target_splits,
|
1721
417
|
>>> )
|
1722
418
|
|
1723
|
-
The
|
419
|
+
The ``results`` tuple contains the output of the chosen analyzer.
|
1724
420
|
|
1725
421
|
See Also
|
1726
422
|
--------
|
1727
423
|
:py:meth:`tme.matching_utils.compute_parallelization_schedule`
|
1728
424
|
"""
|
1729
|
-
|
1730
|
-
|
1731
|
-
)
|
425
|
+
template_splits = split_shape(matching_data._template.shape, splits=template_splits)
|
426
|
+
target_splits = split_shape(matching_data._target.shape, splits=target_splits)
|
1732
427
|
if (len(target_splits) > 1) and not pad_target_edges:
|
1733
428
|
warnings.warn(
|
1734
429
|
"Target splitting without padding target edges leads to unreliable "
|
1735
430
|
"similarity estimates around the split border."
|
1736
431
|
)
|
1737
|
-
|
1738
|
-
template_splits = split_numpy_array_slices(
|
1739
|
-
matching_data._template.shape, splits=template_splits
|
1740
|
-
)
|
1741
|
-
target_pad = matching_data.target_padding(pad_target=pad_target_edges)
|
432
|
+
splits = tuple(product(target_splits, template_splits))
|
1742
433
|
|
1743
434
|
outer_jobs, inner_jobs = job_schedule
|
1744
|
-
|
1745
|
-
|
1746
|
-
|
1747
|
-
|
1748
|
-
matching_data=matching_data
|
1749
|
-
|
1750
|
-
|
1751
|
-
|
1752
|
-
),
|
1753
|
-
matching_score=matching_score,
|
1754
|
-
matching_setup=matching_setup,
|
1755
|
-
n_jobs=inner_jobs,
|
435
|
+
target_pad = matching_data.target_padding(pad_target=pad_target_edges)
|
436
|
+
if hasattr(be, "scan"):
|
437
|
+
corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
|
438
|
+
results = be.scan(
|
439
|
+
matching_data=matching_data,
|
440
|
+
splits=splits,
|
441
|
+
n_jobs=outer_jobs,
|
442
|
+
rotate_mask=matching_score != corr_scoring,
|
1756
443
|
callback_class=callback_class,
|
1757
|
-
callback_class_args=callback_class_args,
|
1758
|
-
interpolation_order=interpolation_order,
|
1759
|
-
pad_fourier=pad_fourier,
|
1760
|
-
gpu_index=index % outer_jobs,
|
1761
|
-
**kwargs,
|
1762
444
|
)
|
1763
|
-
|
1764
|
-
|
445
|
+
else:
|
446
|
+
results = Parallel(n_jobs=outer_jobs)(
|
447
|
+
delayed(_wrap_backend(scan))(
|
448
|
+
backend_name=be._backend_name,
|
449
|
+
backend_args=be._backend_args,
|
450
|
+
matching_data=matching_data.subset_by_slice(
|
451
|
+
target_slice=target_split,
|
452
|
+
target_pad=target_pad,
|
453
|
+
template_slice=template_split,
|
454
|
+
),
|
455
|
+
matching_score=matching_score,
|
456
|
+
matching_setup=matching_setup,
|
457
|
+
n_jobs=inner_jobs,
|
458
|
+
callback_class=callback_class,
|
459
|
+
callback_class_args=callback_class_args,
|
460
|
+
interpolation_order=interpolation_order,
|
461
|
+
pad_fourier=pad_fourier,
|
462
|
+
gpu_index=index % outer_jobs,
|
463
|
+
pad_template_filter=pad_template_filter,
|
464
|
+
)
|
465
|
+
for index, (target_split, template_split) in enumerate(splits)
|
1765
466
|
)
|
1766
|
-
)
|
1767
|
-
|
1768
|
-
matching_data._target, matching_data._template = None, None
|
1769
|
-
matching_data._target_mask, matching_data._template_mask = None, None
|
1770
467
|
|
1771
|
-
|
468
|
+
matching_data._free_data()
|
1772
469
|
if callback_class is not None:
|
1773
|
-
|
1774
|
-
|
1775
|
-
)
|
1776
|
-
|
1777
|
-
return candidates
|
1778
|
-
|
1779
|
-
|
1780
|
-
MATCHING_EXHAUSTIVE_REGISTER = {
|
1781
|
-
"CC": (cc_setup, corr_scoring),
|
1782
|
-
"LCC": (lcc_setup, corr_scoring),
|
1783
|
-
"CORR": (corr_setup, corr_scoring),
|
1784
|
-
"CAM": (cam_setup, corr_scoring),
|
1785
|
-
"FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
|
1786
|
-
"FLC": (flc_setup, flc_scoring),
|
1787
|
-
"FLC2": (flc_setup, flc_scoring2),
|
1788
|
-
"MCC": (mcc_setup, mcc_scoring),
|
1789
|
-
}
|
470
|
+
return callback_class.merge(results, **callback_class_args)
|
471
|
+
return None
|
1790
472
|
|
1791
473
|
|
1792
474
|
def register_matching_exhaustive(
|
@@ -1803,20 +485,17 @@ def register_matching_exhaustive(
|
|
1803
485
|
matching : str
|
1804
486
|
Name of the matching method.
|
1805
487
|
matching_setup : Callable
|
1806
|
-
|
488
|
+
Corresponding setup function.
|
1807
489
|
matching_scoring : Callable
|
1808
|
-
|
490
|
+
Corresponing scoring function.
|
1809
491
|
memory_class : MatchingMemoryUsage
|
1810
|
-
|
1811
|
-
:py:class:`tme.matching_memory.MatchingMemoryUsage`.
|
492
|
+
Child of :py:class:`tme.memory.MatchingMemoryUsage`.
|
1812
493
|
|
1813
494
|
Raises
|
1814
495
|
------
|
1815
496
|
ValueError
|
1816
|
-
If a function with the name ``matching`` already exists in the registry
|
1817
|
-
|
1818
|
-
If ``memory_class`` is not a subclass of
|
1819
|
-
:py:class:`tme.matching_memory.MatchingMemoryUsage`.
|
497
|
+
If a function with the name ``matching`` already exists in the registry, or
|
498
|
+
if ``memory_class`` is no child of :py:class:`tme.memory.MatchingMemoryUsage`.
|
1820
499
|
"""
|
1821
500
|
|
1822
501
|
if matching in MATCHING_EXHAUSTIVE_REGISTER:
|