pytme 0.1.5__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.1.5.data/scripts/estimate_ram_usage.py +81 -0
- pytme-0.1.5.data/scripts/match_template.py +744 -0
- pytme-0.1.5.data/scripts/postprocess.py +279 -0
- pytme-0.1.5.data/scripts/preprocess.py +93 -0
- pytme-0.1.5.data/scripts/preprocessor_gui.py +729 -0
- pytme-0.1.5.dist-info/LICENSE +153 -0
- pytme-0.1.5.dist-info/METADATA +69 -0
- pytme-0.1.5.dist-info/RECORD +63 -0
- pytme-0.1.5.dist-info/WHEEL +5 -0
- pytme-0.1.5.dist-info/entry_points.txt +6 -0
- pytme-0.1.5.dist-info/top_level.txt +2 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +81 -0
- scripts/match_template.py +744 -0
- scripts/match_template_devel.py +788 -0
- scripts/postprocess.py +279 -0
- scripts/preprocess.py +93 -0
- scripts/preprocessor_gui.py +729 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer.py +1144 -0
- tme/backends/__init__.py +134 -0
- tme/backends/cupy_backend.py +309 -0
- tme/backends/matching_backend.py +1154 -0
- tme/backends/npfftw_backend.py +763 -0
- tme/backends/pytorch_backend.py +526 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2314 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/helpers.py +881 -0
- tme/matching_data.py +377 -0
- tme/matching_exhaustive.py +1553 -0
- tme/matching_memory.py +382 -0
- tme/matching_optimization.py +1123 -0
- tme/matching_utils.py +1180 -0
- tme/parser.py +429 -0
- tme/preprocessor.py +1291 -0
- tme/scoring.py +866 -0
- tme/structure.py +1428 -0
- tme/types.py +10 -0
@@ -0,0 +1,1553 @@
|
|
1
|
+
""" Implements template matching based on different metrics. The functions
|
2
|
+
implemented here are equivalent to what is implemented in `Fitter`,
|
3
|
+
but removing them out of the Fitter classes reduces the time it takes
|
4
|
+
when initializing the multiprocessing framework. This is useful in cases,
|
5
|
+
where the number of sampled rotations is low, but the number of input
|
6
|
+
arrays is high, e.g. split fitting and evaluating different templates.
|
7
|
+
|
8
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
9
|
+
|
10
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
11
|
+
"""
|
12
|
+
import os
|
13
|
+
import sys
|
14
|
+
from copy import deepcopy
|
15
|
+
from itertools import product
|
16
|
+
from typing import Callable, Tuple, Dict
|
17
|
+
from functools import wraps
|
18
|
+
from joblib import Parallel, delayed
|
19
|
+
from multiprocessing.managers import SharedMemoryManager
|
20
|
+
|
21
|
+
import numpy as np
|
22
|
+
from scipy.ndimage import laplace
|
23
|
+
|
24
|
+
from .analyzer import MaxScoreOverRotations
|
25
|
+
from .matching_utils import (
|
26
|
+
apply_convolution_mode,
|
27
|
+
handle_traceback,
|
28
|
+
split_numpy_array_slices,
|
29
|
+
conditional_execute,
|
30
|
+
)
|
31
|
+
from .matching_memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
|
32
|
+
from .preprocessor import Preprocessor
|
33
|
+
from .matching_data import MatchingData
|
34
|
+
from .backends import backend
|
35
|
+
from .types import NDArray, CallbackClass
|
36
|
+
|
37
|
+
# Pin the native thread pools used by MKL, OpenMP, pyFFTW and OpenBLAS to a
# single thread each via environment variables.
# NOTE(review): numpy is imported above this point, so pools that were already
# initialized in this process may ignore these values; child processes that
# inherit the environment will see them. Confirm this is the intended effect.
os.environ.update(
    {
        "MKL_NUM_THREADS": "1",
        "OMP_NUM_THREADS": "1",
        "PYFFTW_NUM_THREADS": "1",
        "OPENBLAS_NUM_THREADS": "1",
    }
)
|
41
|
+
|
42
|
+
|
43
|
+
def _run_inner(backend_name, backend_args, **kwargs):
    """
    Switch the global compute backend, then run ``scan``.

    The backend singleton is imported inside the function and reconfigured
    before delegating. NOTE(review): this looks intended as an entry point for
    worker processes; confirm against the callers.

    Parameters
    ----------
    backend_name
        Backend identifier forwarded to ``backend.change_backend``.
    backend_args : Dict
        Keyword arguments forwarded to ``backend.change_backend``.
    **kwargs
        Keyword arguments forwarded unchanged to ``scan``.

    Returns
    -------
    Whatever ``scan(**kwargs)`` returns.
    """
    from tme.backends import backend

    backend.change_backend(backend_name, **backend_args)
    return scan(**kwargs)
|
50
|
+
|
51
|
+
|
52
|
+
def cc_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: NDArray,
    target: NDArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    real_dtype: type,
    complex_dtype: type,
    shared_memory_handler: Callable,
    callback_class: Callable,
    callback_class_args: Dict,
    **kwargs,
) -> Dict:
    """
    Prepare shared inputs for the plain cross-correlation of a target f with
    a template g:

    .. math::

        \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)

    The target is zero-padded (top-left) to ``fast_shape`` and transformed once
    with ``rfftn``. The normalization terms consumed by the scoring function are
    single-element placeholders. ``inv_denominator`` is
    ``preallocate_array(1) + 1`` and ``numerator2`` is ``preallocate_array(1)``.
    These are presumably one and zero, i.e. no normalization; that assumes the
    backend zero-initializes preallocated arrays.

    Parameters
    ----------
    rfftn : Callable
        Real-to-complex transform, called as ``rfftn(real_in, complex_out)``.
    irfftn : Callable
        Inverse transform; accepted for interface parity, unused here.
    template, target : NDArray
        Template and target arrays.
    fast_shape, fast_ft_shape : Tuple[int]
        Padded real-space shape and corresponding Fourier-space shape.
    real_dtype, complex_dtype : type
        Data types recorded alongside the shared buffers.
    shared_memory_handler : Callable
        Passed to ``backend.arr_to_sharedarr`` for every shared buffer.
    callback_class : Callable
        Returned unchanged.
    callback_class_args : Dict
        Returned unchanged.
    **kwargs
        ``use_memmap`` (default ``False``) and ``temp_dir`` (default ``None``)
        are forwarded into the result.

    Returns
    -------
    Dict
        Keyword arguments for the scoring function. Array entries are
        ``(shared_buffer, shape, dtype)`` tuples.

    See Also
    --------
    :py:meth:`corr_scoring`
    :py:class:`tme.matching_optimization.CrossCorrelation`
    """

    def share(array):
        # Move ``array`` into memory managed by ``shared_memory_handler``.
        return backend.arr_to_sharedarr(
            arr=array, shared_memory_handler=shared_memory_handler
        )

    padded_target = backend.topleft_pad(target, fast_shape)
    target_spectrum = backend.preallocate_array(fast_ft_shape, complex_dtype)
    rfftn(padded_target, target_spectrum)

    shared_target_spectrum = share(target_spectrum)
    shared_template = share(template)
    shared_inv_denominator = share(backend.preallocate_array(1, real_dtype) + 1)
    shared_numerator2 = share(backend.preallocate_array(1, real_dtype))

    scalar_shape = (1,)
    return {
        "template": (shared_template, template.shape, real_dtype),
        "ft_target": (shared_target_spectrum, fast_ft_shape, complex_dtype),
        "inv_denominator": (shared_inv_denominator, scalar_shape, real_dtype),
        "numerator2": (shared_numerator2, scalar_shape, real_dtype),
        "targetshape": target.shape,
        "templateshape": template.shape,
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "real_dtype": real_dtype,
        "complex_dtype": complex_dtype,
        "callback_class": callback_class,
        "callback_class_args": callback_class_args,
        "use_memmap": kwargs.get("use_memmap", False),
        "temp_dir": kwargs.get("temp_dir", None),
    }
|
126
|
+
|
127
|
+
|
128
|
+
def lcc_setup(**kwargs) -> Dict:
    """
    Setup to compute the cross-correlation between a laplace transformed target f
    and laplace transformed template g defined as:

    .. math::

        \\mathcal{F}^{-1}(\\mathcal{F}(\\nabla^{2}f) \\cdot \\mathcal{F}(\\nabla^{2} g)^*)

    Both arrays are Laplace filtered with periodic ("wrap") boundaries. CuPy
    arrays use ``cupyx.scipy.ndimage.laplace``; all other arrays use
    ``scipy.ndimage.laplace``. The filtered arrays are then passed to
    :py:meth:`flcSphericalMask_setup`. That function normalizes the score, so
    this setup requires the same keyword arguments it does, including
    ``template_mask``.

    Parameters
    ----------
    **kwargs
        Keyword arguments for :py:meth:`flcSphericalMask_setup`. The entries
        ``target`` and ``template`` are replaced by their Laplacians.

    Returns
    -------
    Dict
        The result of :py:meth:`flcSphericalMask_setup`.

    See Also
    --------
    :py:meth:`corr_scoring`
    :py:class:`tme.matching_optimization.LaplaceCrossCorrelation`
    """

    def _laplace_wrap(arr):
        # Only CuPy arrays need the GPU implementation. Importing cupyx
        # unconditionally made this setup fail wherever CuPy is not installed,
        # and cupyx cannot filter NumPy arrays anyway.
        if type(arr).__module__.split(".")[0] == "cupy":
            from cupyx.scipy.ndimage import laplace as cupy_laplace

            return cupy_laplace(arr, mode="wrap")

        from scipy.ndimage import laplace as scipy_laplace

        return scipy_laplace(arr, mode="wrap")

    kwargs["target"] = _laplace_wrap(kwargs["target"])
    kwargs["template"] = _laplace_wrap(kwargs["template"])
    return flcSphericalMask_setup(**kwargs)
|
147
|
+
|
148
|
+
|
149
|
+
def corr_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: NDArray,
    template_mask: NDArray,
    target: NDArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    real_dtype: type,
    complex_dtype: type,
    shared_memory_handler: Callable,
    callback_class: Callable,
    callback_class_args: Dict,
    **kwargs,
) -> Dict:
    """
    Setup to compute a normalized cross-correlation between a target f and a template g.

    .. math::

        \\frac{CC(f,g) - \\overline{g} \\cdot CC(f, m)}
        {\\sqrt{(CC(f^2, m) - \\frac{CC(f, m)^2}{N_g}) \\cdot \\sum (g - \\overline{g})^2}}

    Where:

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)

    m is the template mask (``template_mask``; a mask of ones gives the classic
    formulation), and N_g is the total number of template elements.

    Only the target-dependent terms are precomputed here:

    - the Fourier transform of the padded target;
    - ``numerator2`` (the term subtracted from the raw correlation);
    - ``inv_denominator`` (inverse of the denominator). Positions where the
      denominator does not exceed machine epsilon keep their preallocated
      value, presumably zero.

    Parameters
    ----------
    rfftn, irfftn : Callable
        Forward/inverse transforms, called as ``rfftn(real_in, complex_out)``
        and ``irfftn(complex_in, real_out)``.
    template, template_mask, target : NDArray
        Template, its mask, and target.
    fast_shape, fast_ft_shape : Tuple[int]
        Padded real-space shape and corresponding Fourier-space shape.
    real_dtype, complex_dtype : type
        Data types for the working and shared buffers.
    shared_memory_handler : Callable
        Passed to ``backend.arr_to_sharedarr`` for every shared buffer.
    callback_class, callback_class_args
        Returned unchanged.
    **kwargs
        ``template_mean`` overrides the computed template mean in the result.

    Returns
    -------
    Dict
        Keyword arguments for the scoring function. Array entries are
        ``(shared_buffer, shape, dtype)`` tuples.

    References
    ----------
    .. [1] J. P. Lewis, "Fast Normalized Cross-Correlation", Industrial Light
           and Magic.

    See Also
    --------
    :py:meth:`corr_scoring`
    :py:class:`tme.matching_optimization.NormalizedCrossCorrelation`.
    """
    target_pad = backend.topleft_pad(target, fast_shape)

    # The exact composition of the denominator is debatable
    # scikit-image match_template multiplies the running sum of the target
    # with a scaling factor derived from the template. This is probably appropriate
    # in pattern matching situations where the template exists in the target
    window_template = backend.topleft_pad(template_mask, fast_shape)
    ft_window_template = backend.preallocate_array(fast_ft_shape, complex_dtype)
    rfftn(window_template, ft_window_template)
    # Drop the reference so the padded mask can be freed early.
    window_template = None

    # Target and squared target window sums
    ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
    ft_target2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
    denominator = backend.preallocate_array(fast_shape, real_dtype)
    target_window_sum = backend.preallocate_array(fast_shape, real_dtype)
    rfftn(target_pad, ft_target)

    # denominator <- window sum of f^2 under the mask.
    # NOTE(review): the spectra are multiplied without complex conjugation;
    # confirm this matches the convention used in the scoring step.
    rfftn(backend.square(target_pad), ft_target2)
    backend.multiply(ft_target2, ft_window_template, out=ft_target2)
    irfftn(ft_target2, denominator)

    # target_window_sum <- window sum of f. This overwrites ft_window_template,
    # so it must come after the f^2 product above.
    backend.multiply(ft_target, ft_window_template, out=ft_window_template)
    irfftn(ft_window_template, target_window_sum)

    target_pad, ft_target2, ft_window_template = None, None, None

    # Normalizing constants
    template_mean = backend.mean(template)
    template_volume = np.prod(template.shape)
    # Sum of squared deviations of the template from its mean.
    template_ssd = backend.sum(
        backend.square(backend.subtract(template, template_mean))
    )

    # Final numerator is score - numerator2
    numerator2 = backend.multiply(target_window_sum, template_mean)

    # Compute denominator:
    # sqrt((CC(f^2, m) - CC(f, m)^2 / N_g) * template_ssd), clamped at zero
    # before the square root to absorb negative round-off.
    backend.multiply(target_window_sum, target_window_sum, out=target_window_sum)
    backend.divide(target_window_sum, template_volume, out=target_window_sum)

    backend.subtract(denominator, target_window_sum, out=denominator)
    backend.multiply(denominator, template_ssd, out=denominator)
    backend.maximum(denominator, 0, out=denominator)
    backend.sqrt(denominator, out=denominator)
    target_window_sum = None

    # Invert denominator to compute final score as product; positions at or
    # below machine epsilon are left at the preallocated value.
    denominator_mask = denominator > backend.eps(denominator.dtype)
    inv_denominator = backend.preallocate_array(fast_shape, real_dtype)
    inv_denominator[denominator_mask] = 1 / denominator[denominator_mask]

    # Convert arrays used in subsequent fitting to SharedMemory objects
    template_buffer = backend.arr_to_sharedarr(
        arr=template, shared_memory_handler=shared_memory_handler
    )
    target_ft_buffer = backend.arr_to_sharedarr(
        arr=ft_target, shared_memory_handler=shared_memory_handler
    )
    inv_denominator_buffer = backend.arr_to_sharedarr(
        arr=inv_denominator, shared_memory_handler=shared_memory_handler
    )
    numerator2_buffer = backend.arr_to_sharedarr(
        arr=numerator2, shared_memory_handler=shared_memory_handler
    )

    template_tuple = (template_buffer, deepcopy(template.shape), real_dtype)
    target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)

    inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
    numerator2_tuple = (numerator2_buffer, fast_shape, real_dtype)

    ft_target, inv_denominator, numerator2 = None, None, None

    ret = {
        "template": template_tuple,
        "ft_target": target_ft_tuple,
        "inv_denominator": inv_denominator_tuple,
        "numerator2": numerator2_tuple,
        "targetshape": deepcopy(target.shape),
        "templateshape": deepcopy(template.shape),
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "real_dtype": real_dtype,
        "complex_dtype": complex_dtype,
        "callback_class": callback_class,
        "callback_class_args": callback_class_args,
        "template_mean": kwargs.get("template_mean", template_mean),
    }

    return ret
|
281
|
+
|
282
|
+
|
283
|
+
def cam_setup(**kwargs):
    """
    Normalized cross-correlation of a target f and a template g taken about
    their means, i.e. :py:meth:`corr_setup` applied to a mean-centered template:

    Notes
    -----

    .. math::

        \\text{CORR}(f-\\overline{f}, g - \\overline{g})

    Where

    .. math::

        \\overline{f}, \\overline{g}

    are the mean of f and g respectively.

    Only the template is centered explicitly. The centered template sums to
    zero, so correlating it with f gives the same result as correlating it
    with f minus any constant. Local normalization of the target comes from
    :py:meth:`corr_setup`.

    Parameters
    ----------
    **kwargs
        Keyword arguments for :py:meth:`corr_setup`; ``template`` is replaced
        by ``template - template.mean()``.

    Returns
    -------
    Dict
        The result of :py:meth:`corr_setup`.

    References
    ----------
    .. [1] J. P. Lewis, "Fast Normalized Cross-Correlation", Industrial Light
           and Magic.

    See Also
    --------
    :py:meth:`corr_scoring`.
    :py:class:`tme.matching_optimization.NormalizedCrossCorrelationMean`.
    """
    original_template = kwargs["template"]
    centered_template = original_template - original_template.mean()
    return corr_setup(**{**kwargs, "template": centered_template})
|
316
|
+
|
317
|
+
|
318
|
+
def flc_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: NDArray,
    template_mask: NDArray,
    target: NDArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    real_dtype: type,
    complex_dtype: type,
    shared_memory_handler: Callable,
    callback_class: Callable,
    callback_class_args: Dict,
    **kwargs,
) -> Dict:
    """
    Setup to compute a normalized cross-correlation score of a target f, a
    template g, and a mask m:

    .. math::

        \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
        {N_m * \\sqrt{
            \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
        }

    Where:

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)

    and N_m is the number of voxels within the template mask m.

    Only the Fourier transforms of the padded target and of the squared
    padded target are precomputed. The mask is binarized (``> 0``), cast to
    ``real_dtype`` and shared as-is for the scoring function. The template is
    standardized to zero mean and unit standard deviation over the masked
    voxels and set to zero outside the mask. This happens in a new array;
    the caller's ``template`` is left unmodified.

    Parameters
    ----------
    rfftn : Callable
        Real-to-complex transform, called as ``rfftn(real_in, complex_out)``.
    irfftn : Callable
        Inverse transform; accepted for interface parity, unused here.
    template, template_mask, target : NDArray
        Template, its mask, and target.
    fast_shape, fast_ft_shape : Tuple[int]
        Padded real-space shape and corresponding Fourier-space shape.
    real_dtype, complex_dtype : type
        Data types recorded alongside the shared buffers.
    shared_memory_handler : Callable
        Passed to ``backend.arr_to_sharedarr`` for every shared buffer.
    callback_class, callback_class_args
        Returned unchanged.

    Returns
    -------
    Dict
        Keyword arguments for the scoring function. Array entries are
        ``(shared_buffer, shape, dtype)`` tuples.

    References
    ----------
    .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
           Microsc. Microanal. 26, 2516 (2020)
    .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
           and F. Förster, J. Struct. Biol. 178, 177 (2012).

    See Also
    --------
    :py:meth:`flc_scoring`
    """
    target_pad = backend.topleft_pad(target, fast_shape)

    # Fourier transforms of the target and the squared target
    ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
    ft_target2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
    rfftn(target_pad, ft_target)
    rfftn(backend.square(target_pad), ft_target2)

    # Convert arrays used in subsequent fitting to SharedMemory objects
    ft_target = backend.arr_to_sharedarr(
        arr=ft_target, shared_memory_handler=shared_memory_handler
    )
    ft_target2 = backend.arr_to_sharedarr(
        arr=ft_target2, shared_memory_handler=shared_memory_handler
    )

    # Template statistics are taken over the masked voxels only.
    template_mask = template_mask > 0
    template_mean = backend.mean(template[template_mask])
    template_std = backend.std(template[template_mask])
    template_mask = backend.astype(template_mask, real_dtype)

    # Standardize into a new array instead of writing into the caller's
    # template. An in-place ``out=template`` mutated the input and cannot hold
    # the float result for integer-typed templates. This matches
    # flcSphericalMask_setup.
    template = backend.divide(
        backend.subtract(template, template_mean), template_std
    )
    backend.multiply(template, template_mask, out=template)

    template_buffer = backend.arr_to_sharedarr(
        arr=template, shared_memory_handler=shared_memory_handler
    )
    template_mask_buffer = backend.arr_to_sharedarr(
        arr=template_mask, shared_memory_handler=shared_memory_handler
    )

    template_tuple = (template_buffer, template.shape, real_dtype)
    template_mask_tuple = (template_mask_buffer, template_mask.shape, real_dtype)

    target_ft_tuple = (ft_target, fast_ft_shape, complex_dtype)
    target_ft2_tuple = (ft_target2, fast_ft_shape, complex_dtype)

    ret = {
        "template": template_tuple,
        "template_mask": template_mask_tuple,
        "ft_target": target_ft_tuple,
        "ft_target2": target_ft2_tuple,
        "targetshape": target.shape,
        "templateshape": template.shape,
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "real_dtype": real_dtype,
        "complex_dtype": complex_dtype,
        "callback_class": callback_class,
        "callback_class_args": callback_class_args,
    }

    return ret
|
416
|
+
|
417
|
+
|
418
|
+
def flcSphericalMask_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: NDArray,
    template_mask: NDArray,
    target: NDArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    real_dtype: type,
    complex_dtype: type,
    shared_memory_handler: Callable,
    callback_class: Callable,
    callback_class_args: Dict,
    **kwargs,
) -> Dict:
    """
    Like :py:meth:`flc_setup` but for rotation invariant masks. In such cases
    the score can be computed quicker by not computing the fourier transforms
    of the mask for each rotation.

    .. math::

        \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
        {N_m * \\sqrt{
            \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
        }

    Where:

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)

    and N_m is the number of voxels within the template mask m
    (here ``sum(template_mask)``).

    The full denominator is computed once here, using the unrotated mask, and
    shared as ``inv_denominator``. Positions whose denominator does not exceed
    ``1e3 * eps * max(|denominator|)`` get an inverse of zero. The template is
    standardized over the masked voxels and zeroed outside the mask.
    ``numerator2`` is a single preallocated element (presumably zero).

    Parameters
    ----------
    rfftn, irfftn : Callable
        Forward/inverse transforms, called as ``rfftn(real_in, complex_out)``
        and ``irfftn(complex_in, real_out)``.
    template, template_mask, target : NDArray
        Template, its mask, and target.
    fast_shape, fast_ft_shape : Tuple[int]
        Padded real-space shape and corresponding Fourier-space shape.
    real_dtype, complex_dtype : type
        Data types for the working and shared buffers.
    shared_memory_handler : Callable
        Passed to ``backend.arr_to_sharedarr`` for every shared buffer.
    callback_class, callback_class_args
        Returned unchanged.

    Returns
    -------
    Dict
        Keyword arguments for the scoring function. Array entries are
        ``(shared_buffer, shape, dtype)`` tuples.

    References
    ----------
    .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
           Microsc. Microanal. 26, 2516 (2020)
    .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
           and F. Förster, J. Struct. Biol. 178, 177 (2012).

    See Also
    --------
    :py:meth:`corr_scoring`
    """
    target_pad = backend.topleft_pad(target, fast_shape)
    template_mask_pad = backend.topleft_pad(template_mask, fast_shape)

    # Working buffers: Fourier-space (ft_*) and real-space (temp, temp2)
    ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
    ft_template_mask = backend.preallocate_array(fast_ft_shape, complex_dtype)
    ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)

    temp = backend.preallocate_array(fast_shape, real_dtype)
    temp2 = backend.preallocate_array(fast_shape, real_dtype)
    numerator2 = backend.preallocate_array(1, real_dtype)

    eps = backend.eps(real_dtype)
    # N_m: sum over the (padded) template mask.
    n_observations = backend.sum(template_mask_pad)
    rfftn(template_mask_pad, ft_template_mask)

    # Variance part denominator: temp2 <- CC(f^2, m) / N_m
    rfftn(backend.square(target_pad), ft_target)
    backend.multiply(ft_target, ft_template_mask, out=ft_temp)
    irfftn(ft_temp, temp2)
    backend.divide(temp2, n_observations, out=temp2)

    # Mean part denominator: temp <- (CC(f, m) / N_m)^2.
    # This also leaves ft_target holding FT(target), which is what gets shared
    # below, so it must run after the variance part.
    rfftn(target_pad, ft_target)
    backend.multiply(ft_target, ft_template_mask, out=ft_temp)
    irfftn(ft_temp, temp)
    backend.divide(temp, n_observations, out=temp)
    backend.square(temp, out=temp)

    backend.subtract(temp2, temp, out=temp)

    # Clamp negative round-off before the square root, then scale by N_m.
    backend.maximum(temp, 0.0, out=temp)
    backend.sqrt(temp, out=temp)
    backend.multiply(temp, n_observations, out=temp)

    # Relative tolerance below which the denominator is treated as zero.
    tol = 1e3 * eps * backend.max(backend.abs(temp))
    nonzero_indices = temp > tol

    # temp2 is reused to hold the inverse denominator (zero where unsafe).
    backend.fill(temp2, 0)
    temp2[nonzero_indices] = 1 / temp[nonzero_indices]

    # Standardize the template over masked voxels; zero it outside the mask.
    template_mask = template_mask > 0
    template_mean = backend.mean(template[template_mask])
    template_std = backend.std(template[template_mask])

    template = backend.divide(backend.subtract(template, template_mean), template_std)
    backend.multiply(template, template_mask, out=template)

    # Convert arrays used in subsequent fitting to SharedMemory objects
    template_buffer = backend.arr_to_sharedarr(
        arr=template, shared_memory_handler=shared_memory_handler
    )
    target_ft_buffer = backend.arr_to_sharedarr(
        arr=ft_target, shared_memory_handler=shared_memory_handler
    )
    inv_denominator_buffer = backend.arr_to_sharedarr(
        arr=temp2, shared_memory_handler=shared_memory_handler
    )
    numerator2_buffer = backend.arr_to_sharedarr(
        arr=numerator2, shared_memory_handler=shared_memory_handler
    )

    template_tuple = (template_buffer, template.shape, real_dtype)
    target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)

    inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
    numerator2_tuple = (numerator2_buffer, (1,), real_dtype)

    ret = {
        "template": template_tuple,
        "ft_target": target_ft_tuple,
        "inv_denominator": inv_denominator_tuple,
        "numerator2": numerator2_tuple,
        "targetshape": target.shape,
        "templateshape": template.shape,
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "real_dtype": real_dtype,
        "complex_dtype": complex_dtype,
        "callback_class": callback_class,
        "callback_class_args": callback_class_args,
    }

    return ret
|
548
|
+
|
549
|
+
|
550
|
+
def mcc_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: NDArray,
    template_mask: NDArray,
    target: NDArray,
    target_mask: NDArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    real_dtype: type,
    complex_dtype: type,
    shared_memory_handler: Callable,
    callback_class: Callable,
    callback_class_args: Dict,
    **kwargs,
) -> Dict:
    """
    Prepare shared inputs for a normalized cross-correlation in which both the
    template and the target carry a mask.

    .. math::

        \\frac{
            CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
        }{
            \\sqrt{
                (CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
                (CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
            }
        }

    Where:

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)

    The target is first zeroed outside ``target_mask > 0``. This writes into
    ``target`` via ``out=target``. Next, the Fourier transforms of the padded
    target, the squared padded target, and the padded target mask are
    computed and shared. Template and template mask are shared without
    further processing.
    NOTE(review): the template mask is recorded with ``template.shape``, so it
    is assumed to match the template's shape.

    Parameters
    ----------
    rfftn : Callable
        Real-to-complex transform, called as ``rfftn(real_in, complex_out)``.
    irfftn : Callable
        Inverse transform; accepted for interface parity, unused here.
    template, template_mask, target, target_mask : NDArray
        Template and target together with their masks.
    fast_shape, fast_ft_shape : Tuple[int]
        Padded real-space shape and corresponding Fourier-space shape.
    real_dtype, complex_dtype : type
        Data types recorded alongside the shared buffers.
    shared_memory_handler : Callable
        Passed to ``backend.arr_to_sharedarr`` for every shared buffer.
    callback_class, callback_class_args
        Returned unchanged.

    Returns
    -------
    Dict
        Keyword arguments for the scoring function. Array entries are
        ``(shared_buffer, shape, dtype)`` tuples.

    References
    ----------
    .. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference

    See Also
    --------
    :py:meth:`mcc_scoring`
    :py:class:`tme.matching_optimization.MaskedCrossCorrelation`
    """

    def _to_shared(array):
        # Move ``array`` into memory managed by ``shared_memory_handler``.
        return backend.arr_to_sharedarr(
            arr=array, shared_memory_handler=shared_memory_handler
        )

    def _shared_spectrum(real_array):
        # Real-to-complex FFT into a fresh buffer, then move it to shared memory.
        spectrum = backend.preallocate_array(fast_ft_shape, complex_dtype)
        rfftn(real_array, spectrum)
        return _to_shared(spectrum)

    # Zero the target outside its mask (in place on ``target``).
    target = backend.multiply(target, target_mask > 0, out=target)

    padded_target = backend.topleft_pad(target, fast_shape)
    padded_target_mask = backend.topleft_pad(target_mask, fast_shape)

    shared_ft_target = _shared_spectrum(padded_target)
    shared_ft_target2 = _shared_spectrum(backend.square(padded_target))
    shared_ft_target_mask = _shared_spectrum(padded_target_mask)

    shared_template = _to_shared(template)
    shared_template_mask = _to_shared(template_mask)

    return {
        "template": (shared_template, template.shape, real_dtype),
        "template_mask": (shared_template_mask, template.shape, real_dtype),
        "ft_target": (shared_ft_target, fast_ft_shape, complex_dtype),
        "ft_target2": (shared_ft_target2, fast_ft_shape, complex_dtype),
        "ft_target_mask": (shared_ft_target_mask, fast_ft_shape, complex_dtype),
        "targetshape": target.shape,
        "templateshape": template.shape,
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "real_dtype": real_dtype,
        "complex_dtype": complex_dtype,
        "callback_class": callback_class,
        "callback_class_args": callback_class_args,
    }
|
651
|
+
|
652
|
+
|
653
|
+
def corr_scoring(
|
654
|
+
template: Tuple[type, Tuple[int], type],
|
655
|
+
ft_target: Tuple[type, Tuple[int], type],
|
656
|
+
inv_denominator: Tuple[type, Tuple[int], type],
|
657
|
+
numerator2: Tuple[type, Tuple[int], type],
|
658
|
+
template_filter: Tuple[type, Tuple[int], type],
|
659
|
+
targetshape: Tuple[int],
|
660
|
+
templateshape: Tuple[int],
|
661
|
+
fast_shape: Tuple[int],
|
662
|
+
fast_ft_shape: Tuple[int],
|
663
|
+
rotations: NDArray,
|
664
|
+
real_dtype: type,
|
665
|
+
complex_dtype: type,
|
666
|
+
callback_class: CallbackClass,
|
667
|
+
callback_class_args: Dict,
|
668
|
+
interpolation_order: int,
|
669
|
+
convolution_mode: str = "full",
|
670
|
+
**kwargs,
|
671
|
+
) -> CallbackClass:
|
672
|
+
"""
|
673
|
+
Calculates a normalized cross-correlation between a target f and a template g.
|
674
|
+
|
675
|
+
.. math::
|
676
|
+
|
677
|
+
(CC(f,g) - numerator2) \\cdot inv\\_denominator
|
678
|
+
|
679
|
+
Parameters
|
680
|
+
----------
|
681
|
+
template : Tuple[type, Tuple[int], type]
|
682
|
+
Tuple containing a pointer to the template data, its shape, and its datatype.
|
683
|
+
ft_target : Tuple[type, Tuple[int], type]
|
684
|
+
Tuple containing a pointer to the fourier tranform of the target,
|
685
|
+
its shape, and its datatype.
|
686
|
+
inv_denominator : Tuple[type, Tuple[int], type]
|
687
|
+
Tuple containing a pointer to the inverse denominator data, its shape, and its
|
688
|
+
datatype.
|
689
|
+
numerator2 : Tuple[type, Tuple[int], type]
|
690
|
+
Tuple containing a pointer to the numerator2 data, its shape, and its datatype.
|
691
|
+
targetshape : Tuple[int]
|
692
|
+
The shape of the target.
|
693
|
+
templateshape : Tuple[int]
|
694
|
+
The shape of the template.
|
695
|
+
fast_shape : Tuple[int]
|
696
|
+
The shape for fast Fourier transform.
|
697
|
+
fast_ft_shape : Tuple[int]
|
698
|
+
The shape for fast Fourier transform of the target.
|
699
|
+
rotations : NDArray
|
700
|
+
Array containing the rotation matrices to be applied on the template.
|
701
|
+
real_dtype : type
|
702
|
+
Data type for the real part of the array.
|
703
|
+
complex_dtype : type
|
704
|
+
Data type for the complex part of the array.
|
705
|
+
callback_class : CallbackClass
|
706
|
+
A callable class or function for processing the results after each
|
707
|
+
rotation.
|
708
|
+
callback_class_args : Dict
|
709
|
+
Dictionary of arguments to be passed to the callback class if it's
|
710
|
+
instantiable.
|
711
|
+
interpolation_order : int
|
712
|
+
The order of interpolation to be used while rotating the template.
|
713
|
+
convolution_mode : str, optional
|
714
|
+
Mode to use for convolution, default is "full".
|
715
|
+
**kwargs :
|
716
|
+
Additional arguments to be passed to the function.
|
717
|
+
|
718
|
+
Returns
|
719
|
+
-------
|
720
|
+
CallbackClass
|
721
|
+
If callback_class was provided an instance of callback_class otherwise None.
|
722
|
+
|
723
|
+
See Also
|
724
|
+
--------
|
725
|
+
:py:meth:`cc_setup`
|
726
|
+
:py:meth:`corr_setup`
|
727
|
+
:py:meth:`cam_setup`
|
728
|
+
:py:meth:`flcSphericalMask_setup`
|
729
|
+
"""
|
730
|
+
template_buffer, template_shape, template_dtype = template
|
731
|
+
ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
|
732
|
+
inv_denominator_buffer, inv_denominator_pointer_shape, _ = inv_denominator
|
733
|
+
numerator2_buffer, numerator2_shape, _ = numerator2
|
734
|
+
filter_buffer, filter_shape, filter_dtype = template_filter
|
735
|
+
|
736
|
+
if callback_class is not None and isinstance(callback_class, type):
|
737
|
+
callback = callback_class(**callback_class_args)
|
738
|
+
elif not isinstance(callback_class, type):
|
739
|
+
callback = callback_class
|
740
|
+
|
741
|
+
# Retrieve objects from shared memory
|
742
|
+
template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
|
743
|
+
ft_target = backend.sharedarr_to_arr(
|
744
|
+
ft_target_shape, ft_target_dtype, ft_target_buffer
|
745
|
+
)
|
746
|
+
inv_denominator = backend.sharedarr_to_arr(
|
747
|
+
inv_denominator_pointer_shape, template_dtype, inv_denominator_buffer
|
748
|
+
)
|
749
|
+
numerator2 = backend.sharedarr_to_arr(
|
750
|
+
numerator2_shape, template_dtype, numerator2_buffer
|
751
|
+
)
|
752
|
+
template_filter = backend.sharedarr_to_arr(
|
753
|
+
filter_shape, filter_dtype, filter_buffer
|
754
|
+
)
|
755
|
+
|
756
|
+
arr = backend.preallocate_array(fast_shape, real_dtype)
|
757
|
+
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
758
|
+
|
759
|
+
rfftn, irfftn = backend.build_fft(
|
760
|
+
fast_shape=fast_shape,
|
761
|
+
fast_ft_shape=fast_ft_shape,
|
762
|
+
real_dtype=real_dtype,
|
763
|
+
complex_dtype=complex_dtype,
|
764
|
+
fftargs=kwargs.get("fftargs", {}),
|
765
|
+
temp_real=arr,
|
766
|
+
temp_fft=ft_temp,
|
767
|
+
)
|
768
|
+
|
769
|
+
norm_numerator = (backend.sum(numerator2) != 0) & (backend.size(numerator2) != 1)
|
770
|
+
norm_denominator = (backend.sum(inv_denominator) != 1) & (
|
771
|
+
backend.size(inv_denominator) != 1
|
772
|
+
)
|
773
|
+
filter_template = backend.size(template_filter) != 0
|
774
|
+
|
775
|
+
norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
|
776
|
+
norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
|
777
|
+
template_filter_func = conditional_execute(backend.multiply, filter_template)
|
778
|
+
|
779
|
+
axis = tuple(range(arr.ndim))
|
780
|
+
fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
|
781
|
+
fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
|
782
|
+
|
783
|
+
for index in range(rotations.shape[0]):
|
784
|
+
rotation = rotations[index]
|
785
|
+
backend.fill(arr, 0)
|
786
|
+
backend.rotate_array(
|
787
|
+
arr=template,
|
788
|
+
rotation_matrix=rotation,
|
789
|
+
out=arr,
|
790
|
+
use_geometric_center=False,
|
791
|
+
order=interpolation_order,
|
792
|
+
)
|
793
|
+
rfftn(arr, ft_temp)
|
794
|
+
template_filter_func(ft_temp, template_filter, out=ft_temp)
|
795
|
+
|
796
|
+
backend.multiply(ft_target, ft_temp, out=ft_temp)
|
797
|
+
irfftn(ft_temp, arr)
|
798
|
+
|
799
|
+
norm_func_numerator(arr, numerator2, out=arr)
|
800
|
+
norm_func_denominator(arr, inv_denominator, out=arr)
|
801
|
+
|
802
|
+
if fourier_shift_scores:
|
803
|
+
arr = backend.roll(arr, shift=fourier_shift, axis=axis)
|
804
|
+
|
805
|
+
score = apply_convolution_mode(
|
806
|
+
arr, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
|
807
|
+
)
|
808
|
+
|
809
|
+
if callback_class is not None:
|
810
|
+
callback(
|
811
|
+
score,
|
812
|
+
rotation_matrix=rotation,
|
813
|
+
rotation_index=index,
|
814
|
+
**callback_class_args,
|
815
|
+
)
|
816
|
+
|
817
|
+
return callback
|
818
|
+
|
819
|
+
|
820
|
+
def flc_scoring(
|
821
|
+
template: Tuple[type, Tuple[int], type],
|
822
|
+
template_mask: Tuple[type, Tuple[int], type],
|
823
|
+
ft_target: Tuple[type, Tuple[int], type],
|
824
|
+
ft_target2: Tuple[type, Tuple[int], type],
|
825
|
+
template_filter: Tuple[type, Tuple[int], type],
|
826
|
+
targetshape: Tuple[int],
|
827
|
+
templateshape: Tuple[int],
|
828
|
+
fast_shape: Tuple[int],
|
829
|
+
fast_ft_shape: Tuple[int],
|
830
|
+
rotations: NDArray,
|
831
|
+
real_dtype: type,
|
832
|
+
complex_dtype: type,
|
833
|
+
callback_class: CallbackClass,
|
834
|
+
callback_class_args: Dict,
|
835
|
+
interpolation_order: int,
|
836
|
+
**kwargs,
|
837
|
+
) -> CallbackClass:
|
838
|
+
"""
|
839
|
+
Computes a normalized cross-correlation score of a target f a template g
|
840
|
+
and a mask m:
|
841
|
+
|
842
|
+
.. math::
|
843
|
+
|
844
|
+
\\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
|
845
|
+
{N_m * \\sqrt{
|
846
|
+
\\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
|
847
|
+
}
|
848
|
+
|
849
|
+
Where:
|
850
|
+
|
851
|
+
.. math::
|
852
|
+
|
853
|
+
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
854
|
+
|
855
|
+
and Nm is the number of voxels within the template mask m.
|
856
|
+
|
857
|
+
References
|
858
|
+
----------
|
859
|
+
.. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
|
860
|
+
Microsc. Microanal. 26, 2516 (2020)
|
861
|
+
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
862
|
+
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
863
|
+
"""
|
864
|
+
template_buffer, template_shape, template_dtype = template
|
865
|
+
template_mask_buffer, *_ = template_mask
|
866
|
+
filter_buffer, filter_shape, filter_dtype = template_filter
|
867
|
+
|
868
|
+
ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
|
869
|
+
ft_target2_buffer, *_ = ft_target2
|
870
|
+
|
871
|
+
if callback_class is not None and isinstance(callback_class, type):
|
872
|
+
callback = callback_class(**callback_class_args)
|
873
|
+
elif not isinstance(callback_class, type):
|
874
|
+
callback = callback_class
|
875
|
+
|
876
|
+
# Retrieve objects from shared memory
|
877
|
+
template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
|
878
|
+
template_mask = backend.sharedarr_to_arr(
|
879
|
+
template_shape, template_dtype, template_mask_buffer
|
880
|
+
)
|
881
|
+
ft_target = backend.sharedarr_to_arr(
|
882
|
+
ft_target_shape, ft_target_dtype, ft_target_buffer
|
883
|
+
)
|
884
|
+
ft_target2 = backend.sharedarr_to_arr(
|
885
|
+
ft_target_shape, ft_target_dtype, ft_target2_buffer
|
886
|
+
)
|
887
|
+
template_filter = backend.sharedarr_to_arr(
|
888
|
+
filter_shape, filter_dtype, filter_buffer
|
889
|
+
)
|
890
|
+
|
891
|
+
arr = backend.preallocate_array(fast_shape, real_dtype)
|
892
|
+
temp = backend.preallocate_array(fast_shape, real_dtype)
|
893
|
+
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
894
|
+
|
895
|
+
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
896
|
+
ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
897
|
+
|
898
|
+
rfftn, irfftn = backend.build_fft(
|
899
|
+
fast_shape=fast_shape,
|
900
|
+
fast_ft_shape=fast_ft_shape,
|
901
|
+
real_dtype=real_dtype,
|
902
|
+
complex_dtype=complex_dtype,
|
903
|
+
fftargs=kwargs.get("fftargs", {}),
|
904
|
+
temp_real=arr,
|
905
|
+
temp_fft=ft_temp,
|
906
|
+
)
|
907
|
+
eps = backend.eps(real_dtype)
|
908
|
+
filter_template = backend.size(template_filter) != 0
|
909
|
+
template_filter_func = conditional_execute(backend.multiply, filter_template)
|
910
|
+
|
911
|
+
axis = tuple(range(arr.ndim))
|
912
|
+
fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
|
913
|
+
fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
|
914
|
+
|
915
|
+
for index in range(rotations.shape[0]):
|
916
|
+
rotation = rotations[index]
|
917
|
+
backend.fill(arr, 0)
|
918
|
+
backend.fill(temp, 0)
|
919
|
+
backend.rotate_array(
|
920
|
+
arr=template,
|
921
|
+
arr_mask=template_mask,
|
922
|
+
rotation_matrix=rotation,
|
923
|
+
out=arr,
|
924
|
+
out_mask=temp,
|
925
|
+
use_geometric_center=False,
|
926
|
+
order=interpolation_order,
|
927
|
+
)
|
928
|
+
n_observations = backend.sum(temp)
|
929
|
+
|
930
|
+
rfftn(temp, ft_temp)
|
931
|
+
|
932
|
+
backend.multiply(ft_target, ft_temp, out=ft_denom)
|
933
|
+
irfftn(ft_denom, temp)
|
934
|
+
backend.divide(temp, n_observations, out=temp)
|
935
|
+
backend.square(temp, out=temp)
|
936
|
+
|
937
|
+
backend.multiply(ft_target2, ft_temp, out=ft_denom)
|
938
|
+
irfftn(ft_denom, temp2)
|
939
|
+
backend.divide(temp2, n_observations, out=temp2)
|
940
|
+
|
941
|
+
backend.subtract(temp2, temp, out=temp)
|
942
|
+
backend.maximum(temp, 0.0, out=temp)
|
943
|
+
backend.sqrt(temp, out=temp)
|
944
|
+
backend.multiply(temp, n_observations, out=temp)
|
945
|
+
|
946
|
+
rfftn(arr, ft_temp)
|
947
|
+
template_filter_func(ft_temp, template_filter, out=ft_temp)
|
948
|
+
backend.multiply(ft_target, ft_temp, out=ft_temp)
|
949
|
+
irfftn(ft_temp, arr)
|
950
|
+
|
951
|
+
tol = tol = 1e3 * eps * backend.max(backend.abs(temp))
|
952
|
+
nonzero_indices = temp > tol
|
953
|
+
backend.fill(temp2, 0)
|
954
|
+
temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
|
955
|
+
|
956
|
+
convolution_mode = kwargs.get("convolution_mode", "full")
|
957
|
+
|
958
|
+
if fourier_shift_scores:
|
959
|
+
temp2 = backend.roll(temp2, shift=fourier_shift, axis=axis)
|
960
|
+
|
961
|
+
score = apply_convolution_mode(
|
962
|
+
temp2, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
|
963
|
+
)
|
964
|
+
|
965
|
+
if callback_class is not None:
|
966
|
+
callback(
|
967
|
+
score,
|
968
|
+
rotation_matrix=rotation,
|
969
|
+
rotation_index=index,
|
970
|
+
**callback_class_args,
|
971
|
+
)
|
972
|
+
|
973
|
+
return callback
|
974
|
+
|
975
|
+
|
976
|
+
def mcc_scoring(
|
977
|
+
template: Tuple[type, Tuple[int], type],
|
978
|
+
template_mask: Tuple[type, Tuple[int], type],
|
979
|
+
ft_target: Tuple[type, Tuple[int], type],
|
980
|
+
ft_target2: Tuple[type, Tuple[int], type],
|
981
|
+
ft_target_mask: Tuple[type, Tuple[int], type],
|
982
|
+
template_filter: Tuple[type, Tuple[int], type],
|
983
|
+
targetshape: Tuple[int],
|
984
|
+
templateshape: Tuple[int],
|
985
|
+
fast_shape: Tuple[int],
|
986
|
+
fast_ft_shape: Tuple[int],
|
987
|
+
rotations: NDArray,
|
988
|
+
real_dtype: type,
|
989
|
+
complex_dtype: type,
|
990
|
+
callback_class: CallbackClass,
|
991
|
+
callback_class_args: type,
|
992
|
+
interpolation_order: int,
|
993
|
+
overlap_ratio: float = 0.3,
|
994
|
+
**kwargs,
|
995
|
+
) -> CallbackClass:
|
996
|
+
"""
|
997
|
+
Computes a cross-correlation score with masks for both template and target.
|
998
|
+
|
999
|
+
.. math::
|
1000
|
+
|
1001
|
+
\\frac{
|
1002
|
+
CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
|
1003
|
+
}{
|
1004
|
+
\\sqrt{
|
1005
|
+
(CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
|
1006
|
+
(CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
|
1007
|
+
}
|
1008
|
+
}
|
1009
|
+
|
1010
|
+
Where:
|
1011
|
+
|
1012
|
+
.. math::
|
1013
|
+
|
1014
|
+
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
1015
|
+
|
1016
|
+
|
1017
|
+
References
|
1018
|
+
----------
|
1019
|
+
.. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
|
1020
|
+
.. [2] https://scikit-image.org/docs/stable/api/skimage.registration.html
|
1021
|
+
|
1022
|
+
See Also
|
1023
|
+
--------
|
1024
|
+
:py:class:`tme.matching_optimization.MaskedCrossCorrelation`
|
1025
|
+
"""
|
1026
|
+
template_buffer, template_shape, template_dtype = template
|
1027
|
+
ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
|
1028
|
+
ft_target2_buffer, ft_target_shape, ft_target_dtype = ft_target2
|
1029
|
+
template_mask_buffer, _, _ = template
|
1030
|
+
ft_target_mask_buffer, _, _ = ft_target
|
1031
|
+
filter_buffer, filter_shape, filter_dtype = template_filter
|
1032
|
+
|
1033
|
+
if callback_class is not None and isinstance(callback_class, type):
|
1034
|
+
callback = callback_class(**callback_class_args)
|
1035
|
+
elif not isinstance(callback_class, type):
|
1036
|
+
callback = callback_class
|
1037
|
+
|
1038
|
+
# Retrieve objects from shared memory
|
1039
|
+
template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
|
1040
|
+
target_ft = backend.sharedarr_to_arr(
|
1041
|
+
ft_target_shape, ft_target_dtype, ft_target_buffer
|
1042
|
+
)
|
1043
|
+
target_ft2 = backend.sharedarr_to_arr(
|
1044
|
+
ft_target_shape, ft_target_dtype, ft_target2_buffer
|
1045
|
+
)
|
1046
|
+
template_mask = backend.sharedarr_to_arr(
|
1047
|
+
template_shape, template_dtype, template_mask_buffer
|
1048
|
+
)
|
1049
|
+
target_mask_ft = backend.sharedarr_to_arr(
|
1050
|
+
ft_target_shape, ft_target_dtype, ft_target_mask_buffer
|
1051
|
+
)
|
1052
|
+
template_filter = backend.sharedarr_to_arr(
|
1053
|
+
filter_shape, filter_dtype, filter_buffer
|
1054
|
+
)
|
1055
|
+
|
1056
|
+
axes = tuple(range(template.ndim))
|
1057
|
+
eps = backend.eps(real_dtype)
|
1058
|
+
|
1059
|
+
# Allocate score and process specific arrays
|
1060
|
+
template_rot = backend.preallocate_array(fast_shape, real_dtype)
|
1061
|
+
mask_overlap = backend.preallocate_array(fast_shape, real_dtype)
|
1062
|
+
numerator = backend.preallocate_array(fast_shape, real_dtype)
|
1063
|
+
|
1064
|
+
temp = backend.preallocate_array(fast_shape, real_dtype)
|
1065
|
+
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
1066
|
+
temp3 = backend.preallocate_array(fast_shape, real_dtype)
|
1067
|
+
temp_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
1068
|
+
|
1069
|
+
rfftn, irfftn = backend.build_fft(
|
1070
|
+
fast_shape=fast_shape,
|
1071
|
+
fast_ft_shape=fast_ft_shape,
|
1072
|
+
real_dtype=real_dtype,
|
1073
|
+
complex_dtype=complex_dtype,
|
1074
|
+
fftargs=kwargs.get("fftargs", {}),
|
1075
|
+
temp_real=numerator,
|
1076
|
+
temp_fft=temp_ft,
|
1077
|
+
)
|
1078
|
+
|
1079
|
+
filter_template = backend.size(template_filter) != 0
|
1080
|
+
template_filter_func = conditional_execute(backend.multiply, filter_template)
|
1081
|
+
|
1082
|
+
axis = tuple(range(template.ndim))
|
1083
|
+
fourier_shift = callback_class_args.get(
|
1084
|
+
"fourier_shift", backend.zeros(template.ndim)
|
1085
|
+
)
|
1086
|
+
fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
|
1087
|
+
|
1088
|
+
# Calculate scores across all rotations
|
1089
|
+
for index in range(rotations.shape[0]):
|
1090
|
+
rotation = rotations[index]
|
1091
|
+
backend.fill(template_rot, 0)
|
1092
|
+
backend.fill(temp, 0)
|
1093
|
+
|
1094
|
+
backend.rotate_array(
|
1095
|
+
arr=template,
|
1096
|
+
arr_mask=template_mask,
|
1097
|
+
rotation_matrix=rotation,
|
1098
|
+
out=template_rot,
|
1099
|
+
out_mask=temp,
|
1100
|
+
use_geometric_center=False,
|
1101
|
+
order=interpolation_order,
|
1102
|
+
)
|
1103
|
+
|
1104
|
+
backend.multiply(template_rot, temp > 0, out=template_rot)
|
1105
|
+
|
1106
|
+
# template_rot_ft
|
1107
|
+
rfftn(template_rot, temp_ft)
|
1108
|
+
template_filter_func(temp_ft, template_filter, out=temp_ft)
|
1109
|
+
irfftn(target_mask_ft * temp_ft, temp2)
|
1110
|
+
irfftn(target_ft * temp_ft, numerator)
|
1111
|
+
|
1112
|
+
# temp template_mask_rot | temp_ft template_mask_rot_ft
|
1113
|
+
# Calculate overlap of masks at every point in the convolution.
|
1114
|
+
# Locations with high overlap should not be taken into account.
|
1115
|
+
rfftn(temp, temp_ft)
|
1116
|
+
irfftn(temp_ft * target_mask_ft, mask_overlap)
|
1117
|
+
mask_overlap[:] = np.round(mask_overlap)
|
1118
|
+
mask_overlap[:] = np.maximum(mask_overlap, eps)
|
1119
|
+
irfftn(temp_ft * target_ft, temp)
|
1120
|
+
|
1121
|
+
backend.subtract(
|
1122
|
+
numerator,
|
1123
|
+
backend.divide(backend.multiply(temp, temp2), mask_overlap),
|
1124
|
+
out=numerator,
|
1125
|
+
)
|
1126
|
+
|
1127
|
+
# temp_3 = fixed_denom
|
1128
|
+
backend.multiply(temp_ft, target_ft2, out=temp_ft)
|
1129
|
+
irfftn(temp_ft, temp3)
|
1130
|
+
backend.subtract(
|
1131
|
+
temp3, backend.divide(backend.square(temp), mask_overlap), out=temp3
|
1132
|
+
)
|
1133
|
+
backend.maximum(temp3, 0.0, out=temp3)
|
1134
|
+
|
1135
|
+
# temp = moving_denom
|
1136
|
+
rfftn(backend.square(template_rot), temp_ft)
|
1137
|
+
backend.multiply(target_mask_ft, temp_ft, out=temp_ft)
|
1138
|
+
irfftn(temp_ft, temp)
|
1139
|
+
|
1140
|
+
backend.subtract(
|
1141
|
+
temp, backend.divide(backend.square(temp2), mask_overlap), out=temp
|
1142
|
+
)
|
1143
|
+
backend.maximum(temp, 0.0, out=temp)
|
1144
|
+
|
1145
|
+
# temp_2 = denom
|
1146
|
+
backend.multiply(temp3, temp, out=temp)
|
1147
|
+
backend.sqrt(temp, temp2)
|
1148
|
+
|
1149
|
+
# Pixels where `denom` is very small will introduce large
|
1150
|
+
# numbers after division. To get around this problem,
|
1151
|
+
# we zero-out problematic pixels.
|
1152
|
+
tol = 1e3 * eps * backend.max(backend.abs(temp2), axis=axes, keepdims=True)
|
1153
|
+
nonzero_indices = temp2 > tol
|
1154
|
+
|
1155
|
+
backend.fill(temp, 0)
|
1156
|
+
temp[nonzero_indices] = numerator[nonzero_indices] / temp2[nonzero_indices]
|
1157
|
+
backend.clip(temp, a_min=-1, a_max=1, out=temp)
|
1158
|
+
|
1159
|
+
# Apply overlap ratio threshold
|
1160
|
+
number_px_threshold = overlap_ratio * backend.max(
|
1161
|
+
mask_overlap, axis=axes, keepdims=True
|
1162
|
+
)
|
1163
|
+
temp[mask_overlap < number_px_threshold] = 0.0
|
1164
|
+
convolution_mode = kwargs.get("convolution_mode", "full")
|
1165
|
+
|
1166
|
+
if fourier_shift_scores:
|
1167
|
+
temp = backend.roll(temp, shift=fourier_shift, axis=axis)
|
1168
|
+
|
1169
|
+
score = apply_convolution_mode(
|
1170
|
+
temp, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
|
1171
|
+
)
|
1172
|
+
if callback_class is not None:
|
1173
|
+
callback(
|
1174
|
+
score,
|
1175
|
+
rotation_matrix=rotation,
|
1176
|
+
rotation_index=index,
|
1177
|
+
**callback_class_args,
|
1178
|
+
)
|
1179
|
+
|
1180
|
+
return callback
|
1181
|
+
|
1182
|
+
|
1183
|
+
def device_memory_handler(func: Callable):
|
1184
|
+
"""Decorator function providing SharedMemory Handler."""
|
1185
|
+
|
1186
|
+
@wraps(func)
|
1187
|
+
def inner_function(*args, **kwargs):
|
1188
|
+
return_value = None
|
1189
|
+
last_type, last_value, last_traceback = sys.exc_info()
|
1190
|
+
try:
|
1191
|
+
with SharedMemoryManager() as smh:
|
1192
|
+
with backend.set_device(kwargs.get("gpu_index", 0)):
|
1193
|
+
return_value = func(shared_memory_handler=smh, *args, **kwargs)
|
1194
|
+
except Exception as e:
|
1195
|
+
print(e)
|
1196
|
+
last_type, last_value, last_traceback = sys.exc_info()
|
1197
|
+
finally:
|
1198
|
+
handle_traceback(last_type, last_value, last_traceback)
|
1199
|
+
return return_value
|
1200
|
+
|
1201
|
+
return inner_function
|
1202
|
+
|
1203
|
+
|
1204
|
+
@device_memory_handler
|
1205
|
+
def scan(
|
1206
|
+
matching_data: MatchingData,
|
1207
|
+
matching_setup: Callable,
|
1208
|
+
matching_score: Callable,
|
1209
|
+
n_jobs: int = 4,
|
1210
|
+
callback_class: CallbackClass = None,
|
1211
|
+
callback_class_args: Dict = {},
|
1212
|
+
fftargs: Dict = {},
|
1213
|
+
pad_fourier: bool = True,
|
1214
|
+
interpolation_order: int = 3,
|
1215
|
+
jobs_per_callback_class: int = 8,
|
1216
|
+
**kwargs,
|
1217
|
+
) -> Tuple:
|
1218
|
+
"""
|
1219
|
+
Perform template matching between target and template and sample
|
1220
|
+
different rotations of template.
|
1221
|
+
|
1222
|
+
Parameters
|
1223
|
+
----------
|
1224
|
+
matching_data : MatchingData
|
1225
|
+
Template matching data.
|
1226
|
+
matching_setup : Callable
|
1227
|
+
Function pointer to setup function.
|
1228
|
+
matching_score : Callable
|
1229
|
+
Function pointer to scoring function.
|
1230
|
+
n_jobs : int, optional
|
1231
|
+
Number of parallel jobs. Default is 4.
|
1232
|
+
callback_class : type, optional
|
1233
|
+
Analyzer class pointer to operate on computed scores.
|
1234
|
+
callback_class_args : dict, optional
|
1235
|
+
Arguments passed to the callback_class. Default is an empty dictionary.
|
1236
|
+
fftargs : dict, optional
|
1237
|
+
Arguments for the FFT operations. Default is an empty dictionary.
|
1238
|
+
pad_fourier: bool, optional
|
1239
|
+
Whether to pad target and template to the full convolution shape.
|
1240
|
+
interpolation_order : int, optional
|
1241
|
+
Order of spline interpolation for rotations.
|
1242
|
+
jobs_per_callback_class : int, optional
|
1243
|
+
How many jobs should be processed by a single callback_class instance,
|
1244
|
+
if ones is provided.
|
1245
|
+
**kwargs : various
|
1246
|
+
Additional arguments.
|
1247
|
+
|
1248
|
+
Returns
|
1249
|
+
-------
|
1250
|
+
Tuple
|
1251
|
+
The merged results from callback_class if provided otherwise None.
|
1252
|
+
"""
|
1253
|
+
matching_data.to_backend()
|
1254
|
+
fourier_pad = matching_data._templateshape
|
1255
|
+
fourier_shift = backend.zeros(len(fourier_pad))
|
1256
|
+
if not pad_fourier:
|
1257
|
+
fourier_pad = backend.full(shape=fourier_shift.shape, fill_value=1, dtype=int)
|
1258
|
+
|
1259
|
+
convolution_shape, fast_shape, fast_ft_shape = backend.compute_convolution_shapes(
|
1260
|
+
matching_data._target.shape, fourier_pad
|
1261
|
+
)
|
1262
|
+
if not pad_fourier:
|
1263
|
+
fourier_shift = 1 - backend.astype(
|
1264
|
+
backend.divide(matching_data._templateshape, 2), int
|
1265
|
+
)
|
1266
|
+
fourier_shift -= backend.mod(matching_data._templateshape, 2)
|
1267
|
+
fourier_shift = backend.flip(fourier_shift, axis=(0,))
|
1268
|
+
shape_diff = backend.subtract(fast_shape, convolution_shape)
|
1269
|
+
shape_diff = backend.astype(backend.divide(shape_diff, 2), int)
|
1270
|
+
backend.add(fourier_shift, shape_diff, out=fourier_shift)
|
1271
|
+
|
1272
|
+
callback_class_args["fourier_shift"] = fourier_shift
|
1273
|
+
rfftn, irfftn = backend.build_fft(
|
1274
|
+
fast_shape=fast_shape,
|
1275
|
+
fast_ft_shape=fast_ft_shape,
|
1276
|
+
real_dtype=matching_data._default_dtype,
|
1277
|
+
complex_dtype=matching_data._complex_dtype,
|
1278
|
+
fftargs=fftargs,
|
1279
|
+
)
|
1280
|
+
setup = matching_setup(
|
1281
|
+
rfftn=rfftn,
|
1282
|
+
irfftn=irfftn,
|
1283
|
+
template=matching_data.template,
|
1284
|
+
template_mask=matching_data.template_mask,
|
1285
|
+
target=matching_data.target,
|
1286
|
+
target_mask=matching_data.target_mask,
|
1287
|
+
fast_shape=fast_shape,
|
1288
|
+
fast_ft_shape=fast_ft_shape,
|
1289
|
+
real_dtype=matching_data._default_dtype,
|
1290
|
+
complex_dtype=matching_data._complex_dtype,
|
1291
|
+
callback_class=callback_class,
|
1292
|
+
callback_class_args=callback_class_args,
|
1293
|
+
**kwargs,
|
1294
|
+
)
|
1295
|
+
rfftn, irfftn = None, None
|
1296
|
+
|
1297
|
+
template_filter, preprocessor = None, Preprocessor()
|
1298
|
+
for method, parameters in matching_data.template_filter.items():
|
1299
|
+
parameters["shape"] = fast_shape
|
1300
|
+
parameters["omit_negative_frequencies"] = True
|
1301
|
+
out = preprocessor.apply_method(method=method, parameters=parameters)
|
1302
|
+
if template_filter is None:
|
1303
|
+
template_filter = out
|
1304
|
+
np.multiply(template_filter, out, out=template_filter)
|
1305
|
+
|
1306
|
+
if template_filter is None:
|
1307
|
+
template_filter = backend.full(
|
1308
|
+
shape=(1,), fill_value=1, dtype=backend._default_dtype
|
1309
|
+
)
|
1310
|
+
else:
|
1311
|
+
template_filter = backend.to_backend_array(template_filter)
|
1312
|
+
|
1313
|
+
template_filter = backend.astype(template_filter, backend._default_dtype)
|
1314
|
+
template_filter_buffer = backend.arr_to_sharedarr(
|
1315
|
+
arr=template_filter,
|
1316
|
+
shared_memory_handler=kwargs.get("shared_memory_handler", None),
|
1317
|
+
)
|
1318
|
+
setup["template_filter"] = (
|
1319
|
+
template_filter_buffer,
|
1320
|
+
template_filter.shape,
|
1321
|
+
template_filter.dtype,
|
1322
|
+
)
|
1323
|
+
|
1324
|
+
callback_class_args["translation_offset"] = backend.astype(
|
1325
|
+
matching_data._translation_offset, int
|
1326
|
+
)
|
1327
|
+
callback_class_args["thread_safe"] = n_jobs > 1
|
1328
|
+
callback_class_args["gpu_index"] = kwargs.get("gpu_index", -1)
|
1329
|
+
|
1330
|
+
n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
|
1331
|
+
callback_class = setup.pop("callback_class", callback_class)
|
1332
|
+
callback_class_args = setup.pop("callback_class_args", callback_class_args)
|
1333
|
+
callback_classes = [callback_class for _ in range(n_callback_classes)]
|
1334
|
+
if callback_class == MaxScoreOverRotations:
|
1335
|
+
score_space_shape = backend.subtract(
|
1336
|
+
matching_data.target.shape,
|
1337
|
+
matching_data._target_pad,
|
1338
|
+
)
|
1339
|
+
callback_classes = [
|
1340
|
+
class_name(
|
1341
|
+
score_space_shape=score_space_shape,
|
1342
|
+
score_space_dtype=matching_data._default_dtype,
|
1343
|
+
shared_memory_handler=kwargs.get("shared_memory_handler", None),
|
1344
|
+
rotation_space_dtype=backend._default_dtype_int,
|
1345
|
+
**callback_class_args,
|
1346
|
+
)
|
1347
|
+
for class_name in callback_classes
|
1348
|
+
]
|
1349
|
+
|
1350
|
+
matching_data._target, matching_data._template = None, None
|
1351
|
+
matching_data._target_mask, matching_data._template_mask = None, None
|
1352
|
+
|
1353
|
+
setup["fftargs"] = fftargs.copy()
|
1354
|
+
convolution_mode = "same"
|
1355
|
+
if backend.sum(matching_data._target_pad) > 0:
|
1356
|
+
convolution_mode = "valid"
|
1357
|
+
setup["convolution_mode"] = convolution_mode
|
1358
|
+
setup["interpolation_order"] = interpolation_order
|
1359
|
+
rotation_list = matching_data._split_rotations_on_jobs(n_jobs)
|
1360
|
+
|
1361
|
+
backend.free_cache()
|
1362
|
+
|
1363
|
+
def _run_scoring(backend_name, backend_args, rotations, **kwargs):
|
1364
|
+
from tme.backends import backend
|
1365
|
+
|
1366
|
+
backend.change_backend(backend_name, **backend_args)
|
1367
|
+
return matching_score(rotations=rotations, **kwargs)
|
1368
|
+
|
1369
|
+
callbacks = Parallel(n_jobs=n_jobs)(
|
1370
|
+
delayed(_run_scoring)(
|
1371
|
+
backend_name=backend._backend_name,
|
1372
|
+
backend_args=backend._backend_args,
|
1373
|
+
rotations=rotation,
|
1374
|
+
callback_class=callback_classes[index % n_callback_classes],
|
1375
|
+
callback_class_args=callback_class_args,
|
1376
|
+
**setup,
|
1377
|
+
)
|
1378
|
+
for index, rotation in enumerate(rotation_list)
|
1379
|
+
)
|
1380
|
+
|
1381
|
+
callbacks = [
|
1382
|
+
tuple(callback)
|
1383
|
+
for callback in callbacks[0:n_callback_classes]
|
1384
|
+
if callback is not None
|
1385
|
+
]
|
1386
|
+
backend.free_cache()
|
1387
|
+
|
1388
|
+
merged_callback = None
|
1389
|
+
if callback_class is not None:
|
1390
|
+
merged_callback = callback_class.merge(
|
1391
|
+
callbacks,
|
1392
|
+
**callback_class_args,
|
1393
|
+
score_indices=matching_data.indices,
|
1394
|
+
inner_merge=True,
|
1395
|
+
)
|
1396
|
+
|
1397
|
+
return merged_callback
|
1398
|
+
|
1399
|
+
|
1400
|
+
def scan_subsets(
|
1401
|
+
matching_data: MatchingData,
|
1402
|
+
matching_score: Callable,
|
1403
|
+
matching_setup: Callable,
|
1404
|
+
callback_class: CallbackClass = None,
|
1405
|
+
callback_class_args: Dict = {},
|
1406
|
+
job_schedule: Tuple[int] = (1, 1),
|
1407
|
+
target_splits: Dict = {},
|
1408
|
+
template_splits: Dict = {},
|
1409
|
+
pad_target_edges: bool = False,
|
1410
|
+
pad_fourier: bool = True,
|
1411
|
+
interpolation_order: int = 3,
|
1412
|
+
jobs_per_callback_class: int = 8,
|
1413
|
+
**kwargs,
|
1414
|
+
) -> Tuple:
|
1415
|
+
"""
|
1416
|
+
Wrapper around :py:meth:`scan` that supports template matching on splits
|
1417
|
+
of template and target.
|
1418
|
+
|
1419
|
+
Parameters
|
1420
|
+
----------
|
1421
|
+
matching_data : MatchingData
|
1422
|
+
Template matching data.
|
1423
|
+
matching_func : type
|
1424
|
+
Function pointer to setup function.
|
1425
|
+
matching_score : type
|
1426
|
+
Function pointer to scoring function.
|
1427
|
+
callback_class : type, optional
|
1428
|
+
Analyzer class pointer to operate on computed scores.
|
1429
|
+
callback_class_args : dict, optional
|
1430
|
+
Arguments passed to the callback_class. Default is an empty dictionary.
|
1431
|
+
job_schedule : tuple of int, optional
|
1432
|
+
Schedule of jobs. Default is (1, 1).
|
1433
|
+
target_splits : dict, optional
|
1434
|
+
Splits for target. Default is an empty dictionary, i.e. no splits
|
1435
|
+
template_splits : dict, optional
|
1436
|
+
Splits for template. Default is an empty dictionary, i.e. no splits.
|
1437
|
+
pad_target_edges : bool, optional
|
1438
|
+
Whether to pad the target boundaries by half the template shape
|
1439
|
+
along each axis.
|
1440
|
+
pad_fourier: bool, optional
|
1441
|
+
Whether to pad target and template to the full convolution shape.
|
1442
|
+
interpolation_order : int, optional
|
1443
|
+
Order of spline interpolation for rotations.
|
1444
|
+
jobs_per_callback_class : int, optional
|
1445
|
+
How many jobs should be processed by a single callback_class instance,
|
1446
|
+
if ones is provided.
|
1447
|
+
**kwargs : various
|
1448
|
+
Additional arguments.
|
1449
|
+
|
1450
|
+
Notes
|
1451
|
+
-----
|
1452
|
+
Objects in matching_data might be destroyed during computation.
|
1453
|
+
|
1454
|
+
Returns
|
1455
|
+
-------
|
1456
|
+
Tuple
|
1457
|
+
The merged results from callback_class if provided otherwise None.
|
1458
|
+
"""
|
1459
|
+
target_splits = split_numpy_array_slices(
|
1460
|
+
matching_data.target.shape, splits=target_splits
|
1461
|
+
)
|
1462
|
+
template_splits = split_numpy_array_slices(
|
1463
|
+
matching_data._templateshape, splits=template_splits
|
1464
|
+
)
|
1465
|
+
|
1466
|
+
target_pad = np.zeros(len(matching_data.target.shape), dtype=int)
|
1467
|
+
if pad_target_edges:
|
1468
|
+
target_pad = np.subtract(
|
1469
|
+
matching_data._templateshape, np.mod(matching_data._templateshape, 2)
|
1470
|
+
)
|
1471
|
+
outer_jobs, inner_jobs = job_schedule
|
1472
|
+
results = Parallel(n_jobs=outer_jobs)(
|
1473
|
+
delayed(_run_inner)(
|
1474
|
+
backend_name=backend._backend_name,
|
1475
|
+
backend_args=backend._backend_args,
|
1476
|
+
matching_data=matching_data.subset_by_slice(
|
1477
|
+
target_slice=target_split,
|
1478
|
+
target_pad=target_pad,
|
1479
|
+
template_slice=template_split,
|
1480
|
+
),
|
1481
|
+
matching_score=matching_score,
|
1482
|
+
matching_setup=matching_setup,
|
1483
|
+
n_jobs=inner_jobs,
|
1484
|
+
callback_class=callback_class,
|
1485
|
+
callback_class_args=callback_class_args,
|
1486
|
+
interpolation_order=interpolation_order,
|
1487
|
+
pad_fourier=pad_fourier,
|
1488
|
+
gpu_index=index % outer_jobs,
|
1489
|
+
**kwargs,
|
1490
|
+
)
|
1491
|
+
for index, (target_split, template_split) in enumerate(
|
1492
|
+
product(target_splits, template_splits)
|
1493
|
+
)
|
1494
|
+
)
|
1495
|
+
|
1496
|
+
matching_data._target, matching_data._template = None, None
|
1497
|
+
matching_data._target_mask, matching_data._template_mask = None, None
|
1498
|
+
|
1499
|
+
if callback_class is not None:
|
1500
|
+
candidates = callback_class.merge(
|
1501
|
+
results, **callback_class_args, inner_merge=False
|
1502
|
+
)
|
1503
|
+
return candidates
|
1504
|
+
|
1505
|
+
|
1506
|
+
MATCHING_EXHAUSTIVE_REGISTER = {
|
1507
|
+
"CC": (cc_setup, corr_scoring),
|
1508
|
+
"LCC": (lcc_setup, corr_scoring),
|
1509
|
+
"CORR": (corr_setup, corr_scoring),
|
1510
|
+
"CAM": (cam_setup, corr_scoring),
|
1511
|
+
"FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
|
1512
|
+
"FLC": (flc_setup, flc_scoring),
|
1513
|
+
"MCC": (mcc_setup, mcc_scoring),
|
1514
|
+
}
|
1515
|
+
|
1516
|
+
|
1517
|
+
def register_matching_exhaustive(
|
1518
|
+
matching: str,
|
1519
|
+
matching_setup: Callable,
|
1520
|
+
matching_scoring: Callable,
|
1521
|
+
memory_class: MatchingMemoryUsage,
|
1522
|
+
) -> None:
|
1523
|
+
"""
|
1524
|
+
Registers a new matching scheme.
|
1525
|
+
|
1526
|
+
Parameters
|
1527
|
+
----------
|
1528
|
+
matching : str
|
1529
|
+
Name of the matching method.
|
1530
|
+
matching_setup : Callable
|
1531
|
+
The setup function associated with the name.
|
1532
|
+
matching_scoring : Callable
|
1533
|
+
The scoring function associated with the name.
|
1534
|
+
memory_class : MatchingMemoryUsage
|
1535
|
+
The custom memory estimation class extending
|
1536
|
+
:py:class:`tme.matching_memory.MatchingMemoryUsage`.
|
1537
|
+
|
1538
|
+
Raises
|
1539
|
+
------
|
1540
|
+
ValueError
|
1541
|
+
If a function with the name ``matching`` already exists in the registry.
|
1542
|
+
ValueError
|
1543
|
+
If ``memory_class`` is not a subclass of
|
1544
|
+
:py:class:`tme.matching_memory.MatchingMemoryUsage`.
|
1545
|
+
"""
|
1546
|
+
|
1547
|
+
if matching in MATCHING_EXHAUSTIVE_REGISTER:
|
1548
|
+
raise ValueError(f"A method with name '{matching}' is already registered.")
|
1549
|
+
if not issubclass(memory_class, MatchingMemoryUsage):
|
1550
|
+
raise ValueError(f"{memory_class} is not a subclass of {MatchingMemoryUsage}.")
|
1551
|
+
|
1552
|
+
MATCHING_EXHAUSTIVE_REGISTER[matching] = (matching_setup, matching_scoring)
|
1553
|
+
MATCHING_MEMORY_REGISTRY[matching] = memory_class
|