pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.2.data/scripts/match_template.py +1187 -0
- {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
- {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
- pytme-0.2.2.dist-info/METADATA +91 -0
- pytme-0.2.2.dist-info/RECORD +74 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
- scripts/extract_candidates.py +126 -87
- scripts/match_template.py +596 -209
- scripts/match_template_filters.py +571 -223
- scripts/postprocess.py +170 -71
- scripts/preprocessor_gui.py +179 -86
- scripts/refine_matches.py +567 -159
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +627 -855
- tme/backends/__init__.py +41 -11
- tme/backends/_jax_utils.py +185 -0
- tme/backends/cupy_backend.py +120 -225
- tme/backends/jax_backend.py +282 -0
- tme/backends/matching_backend.py +464 -388
- tme/backends/mlx_backend.py +45 -68
- tme/backends/npfftw_backend.py +256 -514
- tme/backends/pytorch_backend.py +41 -154
- tme/density.py +312 -421
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +366 -303
- tme/matching_exhaustive.py +279 -1521
- tme/matching_optimization.py +234 -129
- tme/matching_scores.py +884 -0
- tme/matching_utils.py +281 -387
- tme/memory.py +377 -0
- tme/orientations.py +226 -66
- tme/parser.py +3 -4
- tme/preprocessing/__init__.py +2 -0
- tme/preprocessing/_utils.py +217 -0
- tme/preprocessing/composable_filter.py +31 -0
- tme/preprocessing/compose.py +55 -0
- tme/preprocessing/frequency_filters.py +388 -0
- tme/preprocessing/tilt_series.py +1011 -0
- tme/preprocessor.py +574 -530
- tme/structure.py +495 -189
- tme/types.py +5 -3
- pytme-0.2.0b0.data/scripts/match_template.py +0 -800
- pytme-0.2.0b0.dist-info/METADATA +0 -73
- pytme-0.2.0b0.dist-info/RECORD +0 -66
- tme/helpers.py +0 -881
- tme/matching_constrained.py +0 -195
- {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
tme/matching_scores.py
ADDED
@@ -0,0 +1,884 @@
|
|
1
|
+
""" Implements a range of cross-correlation coefficients.
|
2
|
+
|
3
|
+
Copyright (c) 2023-2024 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
import warnings
|
9
|
+
from typing import Callable, Tuple, Dict, Optional
|
10
|
+
|
11
|
+
import numpy as np
|
12
|
+
from scipy.ndimage import laplace
|
13
|
+
|
14
|
+
from .backends import backend as be
|
15
|
+
from .types import CallbackClass, BackendArray, shm_type
|
16
|
+
from .matching_utils import (
|
17
|
+
conditional_execute,
|
18
|
+
identity,
|
19
|
+
normalize_template,
|
20
|
+
_normalize_template_overflow_safe,
|
21
|
+
)
|
22
|
+
|
23
|
+
|
24
|
+
def _shape_match(shape1: Tuple[int], shape2: Tuple[int]) -> bool:
|
25
|
+
"""
|
26
|
+
Determine whether ``shape1`` is equal to ``shape2``.
|
27
|
+
|
28
|
+
Parameters
|
29
|
+
----------
|
30
|
+
shape1, shape2 : tuple of ints
|
31
|
+
Shapes to compare.
|
32
|
+
|
33
|
+
Returns
|
34
|
+
-------
|
35
|
+
Bool
|
36
|
+
``shape1`` is equal to ``shape2``.
|
37
|
+
"""
|
38
|
+
if len(shape1) != len(shape2):
|
39
|
+
return False
|
40
|
+
return shape1 == shape2
|
41
|
+
|
42
|
+
|
43
|
+
def _setup_template_filtering(
    forward_ft_shape: Tuple[int],
    inverse_ft_shape: Tuple[int],
    template_shape: Tuple[int],
    template_filter: BackendArray,
    rfftn: Callable = None,
    irfftn: Callable = None,
) -> Callable:
    """
    Configure template filtering function for Fourier transforms.

    Parameters
    ----------
    forward_ft_shape : tuple of ints
        Shape for the forward Fourier transform.
    inverse_ft_shape : tuple of ints
        Shape for the inverse Fourier transform.
    template_shape : tuple of ints
        Shape of the template to be filtered.
    template_filter : BackendArray
        Precomputed filter to apply in the frequency domain.
    rfftn : Callable, optional
        Real-to-complex FFT function for ``forward_ft_shape`` /
        ``inverse_ft_shape``. Built on demand if omitted.
    irfftn : Callable, optional
        Complex-to-real inverse FFT function matching ``rfftn``. Built on
        demand if omitted.

    Returns
    -------
    Callable
        Filter function with parameters template, ft_temp and template_filter
        that returns the filtered template. If ``template_filter`` has a
        single element, a pass-through built from ``identity`` is returned.

    Notes
    -----
    If the shape of template_filter does not match inverse_ft_shape, the
    template is assumed to be padded. Its leading ``template_shape`` region
    is cropped, filtered with transforms built for ``template_shape``, and
    written back into the padded template, which is then returned. The
    provided rfftn and irfftn are not used in that case.
    """
    if be.size(template_filter) == 1:
        return conditional_execute(identity, identity, False)

    shape_mismatch = not _shape_match(template_filter.shape, inverse_ft_shape)
    if shape_mismatch:
        forward_ft_shape = template_shape
        inverse_ft_shape = template_filter.shape

    # Caller-provided transforms are only valid for the original shapes.
    # Build new ones when filtering operates on the unpadded template, or
    # when the caller did not provide them.
    if shape_mismatch or rfftn is None or irfftn is None:
        rfftn, irfftn = be.build_fft(
            fast_shape=forward_ft_shape,
            fast_ft_shape=inverse_ft_shape,
            real_dtype=be._float_dtype,
            complex_dtype=be._complex_dtype,
            inverse_fast_shape=forward_ft_shape,
        )

    # Default case, all shapes are correctly matched
    def _apply_template_filter(template, ft_temp, template_filter):
        ft_temp = rfftn(template, ft_temp)
        ft_temp = be.multiply(ft_temp, template_filter, out=ft_temp)
        return irfftn(ft_temp, template)

    if not shape_mismatch:
        return _apply_template_filter

    # Template is padded, filter is not. Crop into a contiguous buffer, filter
    # it, and write the result back so callers keep working on the padded
    # array they passed in.
    real_subset = tuple(slice(0, x) for x in template_shape)
    _template = be.zeros(template_shape, be._float_dtype)
    _ft_temp = be.zeros(inverse_ft_shape, be._complex_dtype)

    def _apply_filter_shape_mismatch(template, ft_temp, template_filter):
        _template[:] = template[real_subset]
        template[real_subset] = _apply_template_filter(
            _template, _ft_temp, template_filter
        )
        return template

    return _apply_filter_shape_mismatch
|
117
|
+
|
118
|
+
|
119
|
+
def cc_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: BackendArray,
    target: BackendArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shared_memory_handler: type,
    **kwargs,
) -> Dict:
    """
    Setup function for computing an unnormalized cross-correlation between
    ``target`` (f) and ``template`` (g)

    .. math::

        \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*).

    The returned numerator is a single zero and the inverse denominator a
    single one, i.e. subtracting and scaling by them leaves the raw
    cross-correlation unchanged.

    Notes
    -----
    To be used with :py:meth:`corr_scoring`.
    """
    padded_target = be.topleft_pad(target, fast_shape)
    ft_target = rfftn(padded_target, be.zeros(fast_ft_shape, be._complex_dtype))

    # Neutral placeholders so corr_scoring can share one code path.
    zero_numerator = be.zeros(1, be._float_dtype)
    unit_inv_denominator = be.zeros(1, be._float_dtype) + 1

    def _share(arr):
        return be.to_sharedarr(arr, shared_memory_handler)

    return {
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "template": _share(template),
        "ft_target": _share(ft_target),
        "inv_denominator": _share(unit_inv_denominator),
        "numerator": _share(zero_numerator),
    }
|
157
|
+
|
158
|
+
|
159
|
+
def lcc_setup(target: BackendArray, template: BackendArray, **kwargs) -> Dict:
    """
    Setup function for computing a Laplace cross-correlation between
    ``target`` (f) and ``template`` (g)

    .. math::

        \\mathcal{F}^{-1}(\\mathcal{F}(\\nabla^{2}f) \\cdot \\mathcal{F}(\\nabla^{2} g)^*)

    Both inputs are converted to NumPy, filtered with
    :py:func:`scipy.ndimage.laplace` using periodic (``wrap``) boundaries,
    converted back to the active backend, and forwarded to
    :py:meth:`cc_setup` together with all remaining keyword arguments.

    Notes
    -----
    To be used with :py:meth:`corr_scoring`.
    """

    def _laplacian(arr):
        return be.to_backend_array(laplace(be.to_numpy_array(arr), mode="wrap"))

    kwargs["target"] = _laplacian(target)
    kwargs["template"] = _laplacian(template)
    return cc_setup(**kwargs)
|
177
|
+
|
178
|
+
|
179
|
+
def corr_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: BackendArray,
    template_mask: BackendArray,
    template_filter: BackendArray,
    target: BackendArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shared_memory_handler: type,
    **kwargs,
) -> Dict:
    """
    Setup for computing a normalized cross-correlation between a
    ``target`` (f), a ``template`` (g) given ``template_mask`` (m)

    .. math::

        \\frac{CC(f,g) - \\overline{g} \\cdot CC(f, m)}
        {\\sqrt{(CC(f^2, m) - \\frac{CC(f, m)^2}{N_g}) \\cdot SSD_g}},

    where :math:`N_g` is the number of elements of ``template``,
    :math:`SSD_g` is the masked sum of squared deviations of the normalized
    template from its masked mean :math:`\\overline{g}`, and

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*).

    The returned ``numerator`` holds :math:`\\overline{g} \\cdot CC(f, m)`
    and ``inv_denominator`` the reciprocal of the denominator, set to zero
    wherever the denominator does not exceed machine epsilon.

    Notes
    -----
    To be used with :py:meth:`corr_scoring`.

    References
    ----------
    .. [1] Lewis P. J. Fast Normalized Cross-Correlation, Industrial Light and Magic.
    """
    padded_target = be.topleft_pad(target, fast_shape)

    # The exact composition of the denominator is debatable.
    # scikit-image match_template multiplies the running sum of the target
    # with a scaling factor derived from the template. This is probably
    # appropriate in pattern matching situations where the template exists
    # in the target.
    ft_mask = rfftn(
        be.topleft_pad(template_mask, fast_shape),
        be.zeros(fast_ft_shape, be._complex_dtype),
    )
    ft_target = be.zeros(fast_ft_shape, be._complex_dtype)
    ft_target_sq = be.zeros(fast_ft_shape, be._complex_dtype)
    denominator = be.zeros(fast_shape, be._float_dtype)
    masked_sum = be.zeros(fast_shape, be._float_dtype)

    ft_target = rfftn(padded_target, ft_target)
    ft_target_sq = rfftn(be.square(padded_target), ft_target_sq)

    # CC(f^2, m): local sum of squared target values under the mask
    ft_target_sq = be.multiply(ft_target_sq, ft_mask, out=ft_target_sq)
    denominator = irfftn(ft_target_sq, denominator)

    # CC(f, m): local sum of target values; ft_mask doubles as scratch space
    ft_mask = be.multiply(ft_target, ft_mask, out=ft_mask)
    masked_sum = irfftn(ft_mask, masked_sum)

    # Drop references to intermediates that are no longer needed
    padded_target = ft_target_sq = ft_mask = None

    # TODO: Factor in template_filter here
    if be.size(template_filter) != 1:
        warnings.warn(
            "CORR scores obtained with template_filter are not correctly scaled. "
            "Please use a different score or consider only relative peak heights."
        )

    # Summing a 2-byte mask dtype can overflow, use the overflow-safe path.
    norm_func, n_obs = normalize_template, be.sum(template_mask)
    if be.datatype_bytes(template_mask.dtype) == 2:
        norm_func = _normalize_template_overflow_safe
        n_obs = be.sum(be.astype(template_mask, be._overflow_safe_dtype))

    template = norm_func(template, template_mask, n_obs)
    template_mean = be.divide(be.sum(be.multiply(template, template_mask)), n_obs)
    template_ssd = be.sum(be.square(template - template_mean) * template_mask)
    template_volume = np.prod(tuple(int(x) for x in template.shape))
    template = be.multiply(template, template_mask, out=template)

    numerator = be.multiply(masked_sum, template_mean)
    masked_sum = be.square(masked_sum, out=masked_sum)
    masked_sum = be.divide(masked_sum, template_volume, out=masked_sum)
    denominator = be.subtract(denominator, masked_sum, out=denominator)
    denominator = be.multiply(denominator, template_ssd, out=denominator)
    denominator = be.maximum(denominator, 0, out=denominator)
    denominator = be.sqrt(denominator, out=denominator)

    # Safe reciprocal: entries at or below eps are temporarily set to one
    # before dividing and zeroed afterwards.
    valid = denominator > be.eps(be._float_dtype)
    denominator = be.multiply(denominator, valid, out=denominator)
    denominator = be.add(denominator, ~valid, out=denominator)
    denominator = be.divide(1, denominator, out=denominator)
    denominator = be.multiply(denominator, valid, out=denominator)

    def _share(arr):
        return be.to_sharedarr(arr, shared_memory_handler)

    return {
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "template": _share(template),
        "ft_target": _share(ft_target),
        "inv_denominator": _share(denominator),
        "numerator": _share(numerator),
    }
|
278
|
+
|
279
|
+
|
280
|
+
def cam_setup(template: BackendArray, target: BackendArray, **kwargs) -> Dict:
    """
    Like :py:meth:`corr_setup` but with standardized ``target``, ``template``

    .. math::

        f' = \\frac{f - \\overline{f}}{\\sigma_f}.

    Both arrays are standardized using their global mean and standard
    deviation before being forwarded to :py:meth:`corr_setup` together with
    all remaining keyword arguments.

    Notes
    -----
    To be used with :py:meth:`corr_scoring`.
    """

    def _standardize(arr):
        return (arr - be.mean(arr)) / be.std(arr)

    return corr_setup(
        template=_standardize(template), target=_standardize(target), **kwargs
    )
|
295
|
+
|
296
|
+
|
297
|
+
def flc_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: BackendArray,
    template_mask: BackendArray,
    target: BackendArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shared_memory_handler: type,
    **kwargs,
) -> Dict:
    """
    Setup function for :py:meth:`flc_scoring`.

    Computes the Fourier transforms of the padded target and of the squared
    padded target, and normalizes ``template`` within ``template_mask``.

    Parameters
    ----------
    rfftn : Callable
        Real-to-complex FFT function for ``fast_shape``.
    irfftn : Callable
        Complex-to-real inverse FFT function (unused here).
    template : BackendArray
        Template to normalize.
    template_mask : BackendArray
        Mask defining the region used for template normalization.
    target : BackendArray
        Target to search.
    fast_shape : tuple of ints
        Data shape for the forward Fourier transform.
    fast_ft_shape : tuple of ints
        Data shape for the inverse Fourier transform.
    shared_memory_handler : type
        Handler passed to ``be.to_sharedarr``.

    Returns
    -------
    Dict
        Shapes and shared arrays ``template``, ``template_mask``,
        ``ft_target`` and ``ft_target2``.
    """
    target_pad = be.topleft_pad(target, fast_shape)
    ft_target = be.zeros(fast_ft_shape, be._complex_dtype)
    ft_target2 = be.zeros(fast_ft_shape, be._complex_dtype)

    ft_target = rfftn(target_pad, ft_target)
    target_pad = be.square(target_pad, out=target_pad)
    ft_target2 = rfftn(target_pad, ft_target2)

    # As in the other setup functions of this module, summing a 2-byte mask
    # dtype (such as float16) can overflow, so use the overflow-safe path.
    n_obs, norm_func = be.sum(template_mask), normalize_template
    if be.datatype_bytes(template_mask.dtype) == 2:
        norm_func = _normalize_template_overflow_safe
        n_obs = be.sum(be.astype(template_mask, be._overflow_safe_dtype))
    template = norm_func(template, template_mask, n_obs)

    ret = {
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "template": be.to_sharedarr(template, shared_memory_handler),
        "template_mask": be.to_sharedarr(template_mask, shared_memory_handler),
        "ft_target": be.to_sharedarr(ft_target, shared_memory_handler),
        "ft_target2": be.to_sharedarr(ft_target2, shared_memory_handler),
    }

    return ret
|
330
|
+
|
331
|
+
|
332
|
+
def flcSphericalMask_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: BackendArray,
    template_mask: BackendArray,
    target: BackendArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shared_memory_handler: type,
    **kwargs,
) -> Dict:
    """
    Setup for :py:meth:`corr_scoring`, like :py:meth:`flc_setup` but for
    rotation invariant masks.

    Because the mask does not change under rotation, the local target
    standard deviation under the mask is computed once here. The returned
    ``inv_denominator`` holds :math:`1 / (N_m \\cdot \\sigma_f)` and is zero
    wherever :math:`\\sigma_f` does not exceed machine epsilon; ``numerator``
    is a single zero.
    """
    n_obs, norm_func = be.sum(template_mask), normalize_template
    if be.datatype_bytes(template_mask.dtype) == 2:
        norm_func = _normalize_template_overflow_safe
        n_obs = be.sum(be.astype(template_mask, be._overflow_safe_dtype))

    target_pad = be.topleft_pad(target, fast_shape)
    local_std = be.zeros(fast_shape, be._float_dtype)
    mean_sq = be.zeros(fast_shape, be._float_dtype)
    numerator = be.zeros(1, be._float_dtype)
    ft_target = be.zeros(fast_ft_shape, be._complex_dtype)
    ft_mask = be.zeros(fast_ft_shape, be._complex_dtype)
    ft_buffer = be.zeros(fast_ft_shape, be._complex_dtype)

    template = norm_func(template, template_mask, n_obs)
    ft_mask = rfftn(be.topleft_pad(template_mask, fast_shape), ft_mask)

    # Local E(X^2) under the mask
    ft_target = rfftn(be.square(target_pad), ft_target)
    ft_buffer = be.multiply(ft_target, ft_mask, out=ft_buffer)
    mean_sq = irfftn(ft_buffer, mean_sq)
    mean_sq = be.divide(mean_sq, n_obs, out=mean_sq)

    # Local E(X)^2 under the mask; ft_target ends up holding the transform
    # of the (unsquared) padded target, which is what gets returned.
    ft_target = rfftn(target_pad, ft_target)
    ft_buffer = be.multiply(ft_target, ft_mask, out=ft_buffer)
    local_std = irfftn(ft_buffer, local_std)
    local_std = be.divide(local_std, n_obs, out=local_std)
    local_std = be.square(local_std, out=local_std)

    # sigma = sqrt(max(E(X^2) - E(X)^2, 0))
    local_std = be.subtract(mean_sq, local_std, out=local_std)
    local_std = be.maximum(local_std, 0.0, out=local_std)
    local_std = be.sqrt(local_std, out=local_std)

    # Avoid divide by zero warnings: entries at or below eps are set to one
    # before taking the reciprocal and zeroed afterwards.
    valid = local_std > be.eps(be._float_dtype)
    local_std = be.multiply(local_std, valid * n_obs, out=local_std)
    local_std = be.add(local_std, ~valid, out=local_std)
    inv_denominator = be.divide(1, local_std, out=local_std)
    inv_denominator = be.multiply(inv_denominator, valid, out=inv_denominator)

    def _share(arr):
        return be.to_sharedarr(arr, shared_memory_handler)

    return {
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "template": _share(template),
        "template_mask": _share(template_mask),
        "ft_target": _share(ft_target),
        "inv_denominator": _share(inv_denominator),
        "numerator": _share(numerator),
    }
|
399
|
+
|
400
|
+
|
401
|
+
def mcc_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: BackendArray,
    template_mask: BackendArray,
    target: BackendArray,
    target_mask: BackendArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shared_memory_handler: Callable,
    **kwargs,
) -> Dict:
    """
    Setup function for :py:meth:`mcc_scoring`.

    Zeroes the target outside ``target_mask`` and computes the Fourier
    transforms of the masked target, the squared masked target and the
    target mask, each padded to ``fast_shape``. The caller's ``target`` is
    not modified.

    Parameters
    ----------
    rfftn : Callable
        Real-to-complex FFT function for ``fast_shape``.
    irfftn : Callable
        Complex-to-real inverse FFT function (unused here).
    template : BackendArray
        Template data.
    template_mask : BackendArray
        Template mask.
    target : BackendArray
        Target data.
    target_mask : BackendArray
        Target mask; target values where the mask is not positive are zeroed.
    fast_shape : tuple of ints
        Data shape for the forward Fourier transform.
    fast_ft_shape : tuple of ints
        Data shape for the inverse Fourier transform.
    shared_memory_handler : Callable
        Handler passed to ``be.to_sharedarr``.

    Returns
    -------
    Dict
        Shapes and shared arrays ``template``, ``template_mask``,
        ``ft_target``, ``ft_target2`` and ``ft_target_mask``.
    """
    target_mask_pad = be.topleft_pad(target_mask, fast_shape)

    # Mask a padded copy instead of writing into ``target`` so the caller's
    # array is left untouched. Assuming topleft_pad pads with zeros, this
    # matches masking before padding.
    target_pad = be.topleft_pad(target, fast_shape)
    target_pad = be.multiply(target_pad, target_mask_pad > 0, out=target_pad)

    ft_target = be.zeros(fast_ft_shape, be._complex_dtype)
    ft_target2 = be.zeros(fast_ft_shape, be._complex_dtype)
    target_mask_ft = be.zeros(fast_ft_shape, be._complex_dtype)

    ft_target = rfftn(target_pad, ft_target)
    ft_target2 = rfftn(be.square(target_pad), ft_target2)
    target_mask_ft = rfftn(target_mask_pad, target_mask_ft)

    ret = {
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "template": be.to_sharedarr(template, shared_memory_handler),
        "template_mask": be.to_sharedarr(template_mask, shared_memory_handler),
        "ft_target": be.to_sharedarr(ft_target, shared_memory_handler),
        "ft_target2": be.to_sharedarr(ft_target2, shared_memory_handler),
        "ft_target_mask": be.to_sharedarr(target_mask_ft, shared_memory_handler),
    }

    return ret
|
438
|
+
|
439
|
+
|
440
|
+
def corr_scoring(
    template: shm_type,
    template_filter: shm_type,
    ft_target: shm_type,
    inv_denominator: shm_type,
    numerator: shm_type,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: BackendArray,
    callback: CallbackClass,
    interpolation_order: int,
    template_mask: shm_type = None,
) -> Optional[CallbackClass]:
    """
    Calculates a normalized cross-correlation between a target f and a template g.

    .. math::

        (CC(f,g) - \\text{numerator}) \\cdot \\text{inv_denominator},

    where

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*).

    Parameters
    ----------
    template : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template data buffer, its shape and datatype.
    template_filter : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template filter data buffer, its shape and datatype.
    ft_target : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed target data buffer, its shape and datatype.
    inv_denominator : Union[Tuple[type, tuple of ints, type], BackendArray]
        Inverse denominator data buffer, its shape and datatype.
    numerator : Union[Tuple[type, tuple of ints, type], BackendArray]
        Numerator data buffer, its shape, and its datatype.
    fast_shape: tuple of ints
        Data shape for the forward Fourier transform.
    fast_ft_shape: tuple of ints
        Data shape for the inverse Fourier transform.
    rotations : BackendArray
        Rotation matrices to be sampled (n, d, d).
    callback : CallbackClass
        A callable for processing the result of each rotation.
    interpolation_order : int
        Spline order for template rotations.
    template_mask : Union[Tuple[type, tuple of ints, type], BackendArray], optional
        Template mask data buffer, its shape and datatype, None by default.
        If given, each rotated template is normalized within the mask.

    Returns
    -------
    Optional[CallbackClass]
        ``callback`` if provided otherwise None.
    """
    template = be.from_sharedarr(template)
    ft_target = be.from_sharedarr(ft_target)
    inv_denominator = be.from_sharedarr(inv_denominator)
    numerator = be.from_sharedarr(numerator)
    template_filter = be.from_sharedarr(template_filter)

    norm_func, mask_sum = normalize_template, 1
    apply_normalization = template_mask is not None
    if apply_normalization:
        template_mask = be.from_sharedarr(template_mask)
        mask_sum = be.sum(template_mask)
        # Summing a 2-byte mask dtype can overflow, use the overflow-safe path.
        if be.datatype_bytes(template_mask.dtype) == 2:
            norm_func = _normalize_template_overflow_safe
            mask_sum = be.sum(be.astype(template_mask, be._overflow_safe_dtype))

    callback_func = conditional_execute(callback, callback is not None)
    normalize = conditional_execute(norm_func, apply_normalization)
    # NOTE(review): unlike ``callback_func`` above, these two calls pass
    # ``identity`` as the second argument and the shape check as the third.
    # Confirm against conditional_execute's signature whether the shape
    # check is actually consulted here.
    subtract_numerator = conditional_execute(
        be.subtract, identity, _shape_match(numerator.shape, fast_shape)
    )
    scale_by_denominator = conditional_execute(
        be.multiply, identity, _shape_match(inv_denominator.shape, fast_shape)
    )

    rotated = be.zeros(fast_shape, be._float_dtype)
    ft_buffer = be.zeros(fast_ft_shape, be._complex_dtype)
    rfftn, irfftn = be.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=be._float_dtype,
        complex_dtype=be._complex_dtype,
        temp_real=rotated,
        temp_fft=ft_buffer,
    )

    filter_template = _setup_template_filtering(
        forward_ft_shape=fast_shape,
        inverse_ft_shape=fast_ft_shape,
        template_shape=template.shape,
        template_filter=template_filter,
        rfftn=rfftn,
        irfftn=irfftn,
    )
    unpadded = tuple(slice(0, stop) for stop in template.shape)

    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        rotated = be.fill(rotated, 0)
        rotated, _ = be.rigid_transform(
            arr=template,
            rotation_matrix=rotation,
            out=rotated,
            use_geometric_center=True,
            order=interpolation_order,
            cache=True,
        )
        rotated = filter_template(rotated, ft_buffer, template_filter)
        # Return value is discarded: this relies on the normalization
        # modifying the unpadded view of ``rotated`` in place.
        normalize(rotated[unpadded], template_mask, mask_sum)

        ft_buffer = rfftn(rotated, ft_buffer)
        ft_buffer = be.multiply(ft_target, ft_buffer, out=ft_buffer)
        rotated = irfftn(ft_buffer, rotated)

        rotated = subtract_numerator(rotated, numerator, out=rotated)
        rotated = scale_by_denominator(rotated, inv_denominator, out=rotated)
        callback_func(rotated, rotation_matrix=rotation)

    return callback
|
562
|
+
|
563
|
+
|
564
|
+
def flc_scoring(
    template: shm_type,
    template_mask: shm_type,
    ft_target: shm_type,
    ft_target2: shm_type,
    template_filter: shm_type,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: BackendArray,
    callback: CallbackClass,
    interpolation_order: int,
) -> Optional[CallbackClass]:
    """
    Computes a normalized cross-correlation between ``target`` (f),
    ``template`` (g), and ``template_mask`` (m)

    .. math::

        \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
        {N_m * \\sqrt{
            \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
        },

    where

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)

    and :math:`N_m` is the sum of the rotated mask m. Template and mask are
    rotated together, so the local target statistics are recomputed for
    every rotation.

    Parameters
    ----------
    template : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template data buffer, its shape and datatype.
    template_mask : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template mask data buffer, its shape and datatype.
    ft_target : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed target data buffer, its shape and datatype.
    ft_target2 : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed squared target data buffer, its shape and datatype.
    template_filter : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template filter data buffer, its shape and datatype.
    fast_shape : tuple of ints
        Data shape for the forward Fourier transform.
    fast_ft_shape : tuple of ints
        Data shape for the inverse Fourier transform.
    rotations : BackendArray
        Rotation matrices to be sampled (n, d, d).
    callback : CallbackClass
        A callable for processing the result of each rotation.
    interpolation_order : int
        Spline order for template rotations.

    Returns
    -------
    Optional[CallbackClass]
        ``callback`` if provided otherwise None.

    References
    ----------
    .. [1] Hrabe T. et al, J. Struct. Biol. 178, 177 (2012).
    """
    float_dtype, complex_dtype = be._float_dtype, be._complex_dtype
    template = be.from_sharedarr(template)
    template_mask = be.from_sharedarr(template_mask)
    ft_target = be.from_sharedarr(ft_target)
    ft_target2 = be.from_sharedarr(ft_target2)
    template_filter = be.from_sharedarr(template_filter)

    # ``scores`` holds the rotated template and later the scores; ``aux``
    # holds the rotated mask and later CC(f, m); ``aux_sq`` holds CC(f^2, m).
    scores = be.zeros(fast_shape, float_dtype)
    aux = be.zeros(fast_shape, float_dtype)
    aux_sq = be.zeros(fast_shape, float_dtype)
    ft_buffer = be.zeros(fast_ft_shape, complex_dtype)
    ft_denom = be.zeros(fast_ft_shape, complex_dtype)

    rfftn, irfftn = be.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=float_dtype,
        complex_dtype=complex_dtype,
        temp_real=scores,
        temp_fft=ft_buffer,
    )

    filter_template = _setup_template_filtering(
        forward_ft_shape=fast_shape,
        inverse_ft_shape=fast_ft_shape,
        template_shape=template.shape,
        template_filter=template_filter,
        rfftn=rfftn,
        irfftn=irfftn,
    )

    eps = be.eps(float_dtype)
    callback_func = conditional_execute(callback, callback is not None)
    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        scores = be.fill(scores, 0)
        aux = be.fill(aux, 0)
        scores, aux = be.rigid_transform(
            arr=template,
            arr_mask=template_mask,
            rotation_matrix=rotation,
            out=scores,
            out_mask=aux,
            use_geometric_center=True,
            order=interpolation_order,
            cache=True,
        )

        n_obs = be.sum(aux)
        scores = filter_template(scores, ft_buffer, template_filter)
        scores = normalize_template(scores, aux, n_obs)

        # Local target statistics under the rotated mask
        ft_buffer = rfftn(aux, ft_buffer)
        ft_denom = be.multiply(ft_target, ft_buffer, out=ft_denom)
        aux = irfftn(ft_denom, aux)
        ft_denom = be.multiply(ft_target2, ft_buffer, out=ft_denom)
        aux_sq = irfftn(ft_denom, aux_sq)

        # Cross-correlation of target and normalized template
        ft_buffer = rfftn(scores, ft_buffer)
        ft_buffer = be.multiply(ft_target, ft_buffer, out=ft_buffer)
        scores = irfftn(ft_buffer, scores)

        scores = be.norm_scores(scores, aux_sq, aux, n_obs, eps, scores)
        callback_func(scores, rotation_matrix=rotation)

    return callback
|
695
|
+
|
696
|
+
|
697
|
+
def mcc_scoring(
    template: shm_type,
    template_mask: shm_type,
    template_filter: shm_type,
    ft_target: shm_type,
    ft_target2: shm_type,
    ft_target_mask: shm_type,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: BackendArray,
    callback: CallbackClass,
    interpolation_order: int,
    overlap_ratio: float = 0.3,
) -> CallbackClass:
    """
    Computes a masked normalized cross-correlation score between ``target`` (f),
    ``template`` (g), ``template_mask`` (m) and ``target_mask`` (t)

    .. math::

        \\frac{
            CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
        }{
            \\sqrt{
                (CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
                (CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
            }
        },

    where

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*).

    For every rotation in ``rotations``, the template and its mask are
    rigidly transformed, the rotated template is passed through the template
    filter and ``normalize_template`` (using the rotated mask), and the
    resulting score array is handed to ``callback`` together with the
    rotation matrix. Scores are clipped to [-1, 1]. Positions whose mask
    overlap is below ``overlap_ratio`` times the maximal overlap are set to
    zero, as are positions whose denominator is numerically negligible.

    Parameters
    ----------
    template : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template data buffer, its shape and datatype.
    template_mask : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template mask data buffer, its shape and datatype.
    template_filter : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template filter data buffer, its shape and datatype.
    ft_target : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed target data buffer, its shape and datatype.
    ft_target2 : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed squared target data buffer, its shape and datatype.
    ft_target_mask : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed target mask data buffer, its shape and datatype.
    fast_shape: tuple of ints
        Data shape for the forward Fourier transform.
    fast_ft_shape: tuple of ints
        Data shape for the inverse Fourier transform.
    rotations : BackendArray
        Rotation matrices to be sampled (n, d, d).
    callback : CallbackClass
        A callable for processing the result of each rotation.
    interpolation_order : int
        Spline order for template rotations.
    overlap_ratio : float, optional
        Required fractional mask overlap, 0.3 by default.

    Returns
    -------
    CallbackClass
        The ``callback`` object that was passed in.

    References
    ----------
    .. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
    .. [2] https://scikit-image.org/docs/stable/api/skimage.registration.html
    """
    float_dtype, complex_dtype = be._float_dtype, be._complex_dtype
    template = be.from_sharedarr(template)
    target_ft = be.from_sharedarr(ft_target)
    target_ft2 = be.from_sharedarr(ft_target2)
    template_mask = be.from_sharedarr(template_mask)
    target_mask_ft = be.from_sharedarr(ft_target_mask)
    template_filter = be.from_sharedarr(template_filter)

    axes = tuple(range(template.ndim))
    eps = be.eps(float_dtype)

    # Allocate score and process specific arrays. The temp buffers are reused
    # for different intermediate quantities within each rotation.
    template_rot = be.zeros(fast_shape, float_dtype)
    mask_overlap = be.zeros(fast_shape, float_dtype)
    numerator = be.zeros(fast_shape, float_dtype)
    temp = be.zeros(fast_shape, float_dtype)
    temp2 = be.zeros(fast_shape, float_dtype)
    temp3 = be.zeros(fast_shape, float_dtype)
    temp_ft = be.zeros(fast_ft_shape, complex_dtype)

    rfftn, irfftn = be.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=float_dtype,
        complex_dtype=complex_dtype,
        temp_real=numerator,
        temp_fft=temp_ft,
    )

    template_filter_func = _setup_template_filtering(
        forward_ft_shape=fast_shape,
        inverse_ft_shape=fast_ft_shape,
        template_shape=template.shape,
        template_filter=template_filter,
        rfftn=rfftn,
        irfftn=irfftn,
    )

    callback_func = conditional_execute(callback, callback is not None)
    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        template_rot = be.fill(template_rot, 0)
        temp = be.fill(temp, 0)
        # temp receives the rotated template mask.
        be.rigid_transform(
            arr=template,
            arr_mask=template_mask,
            rotation_matrix=rotation,
            out=template_rot,
            out_mask=temp,
            use_geometric_center=True,
            order=interpolation_order,
            cache=True,
        )

        template_filter_func(template_rot, temp_ft, template_filter)
        normalize_template(template_rot, temp, be.sum(temp))

        # temp2 = CC(t, g), numerator = CC(f, g)
        temp_ft = rfftn(template_rot, temp_ft)
        temp2 = irfftn(target_mask_ft * temp_ft, temp2)
        numerator = irfftn(target_ft * temp_ft, numerator)

        # temp_ft now holds the transform of the rotated template mask.
        # Calculate overlap of masks at every point in the convolution.
        # Locations with low overlap are suppressed further below.
        temp_ft = rfftn(temp, temp_ft)
        mask_overlap = irfftn(temp_ft * target_mask_ft, mask_overlap)
        be.maximum(mask_overlap, eps, out=mask_overlap)
        # temp = CC(f, m)
        temp = irfftn(temp_ft * target_ft, temp)

        be.subtract(
            numerator,
            be.divide(be.multiply(temp, temp2), mask_overlap),
            out=numerator,
        )

        # temp3 = fixed (target) denominator term
        be.multiply(temp_ft, target_ft2, out=temp_ft)
        temp3 = irfftn(temp_ft, temp3)
        be.subtract(temp3, be.divide(be.square(temp), mask_overlap), out=temp3)
        be.maximum(temp3, 0.0, out=temp3)

        # temp = moving (template) denominator term
        temp_ft = rfftn(be.square(template_rot), temp_ft)
        be.multiply(target_mask_ft, temp_ft, out=temp_ft)
        temp = irfftn(temp_ft, temp)

        be.subtract(temp, be.divide(be.square(temp2), mask_overlap), out=temp)
        be.maximum(temp, 0.0, out=temp)

        # temp2 = denom
        be.multiply(temp3, temp, out=temp)
        be.sqrt(temp, temp2)

        # Pixels where `denom` is very small will introduce large
        # numbers after division. To get around this problem, we divide
        # by one at those pixels and zero-out the result afterwards.
        # `<=` (rather than `<`) also covers max(denom) == 0, where tol is 0
        # and a strict comparison would leave zero denominators in place,
        # yielding NaN/inf scores.
        tol = 1e3 * eps * be.max(be.abs(temp2), axis=axes, keepdims=True)

        unstable = temp2 <= tol
        temp2[unstable] = 1
        temp = be.divide(numerator, temp2, out=temp)
        temp[unstable] = 0.0
        temp = be.clip(temp, a_min=-1, a_max=1, out=temp)

        # Apply overlap ratio threshold
        number_px_threshold = overlap_ratio * be.max(
            mask_overlap, axis=axes, keepdims=True
        )
        temp[mask_overlap < number_px_threshold] = 0.0
        callback_func(temp, rotation_matrix=rotation)

    return callback
|
874
|
+
|
875
|
+
|
876
|
+
# Registry of exhaustive template-matching scores: each score name maps to a
# (setup function, scoring function) pair. The first group of scores shares
# the generic ``corr_scoring`` routine and differs only in its setup step.
MATCHING_EXHAUSTIVE_REGISTER = {
    **{
        score_name: (setup_func, corr_scoring)
        for score_name, setup_func in (
            ("CC", cc_setup),
            ("LCC", lcc_setup),
            ("CORR", corr_setup),
            ("CAM", cam_setup),
            ("FLCSphericalMask", flcSphericalMask_setup),
        )
    },
    "FLC": (flc_setup, flc_scoring),
    "MCC": (mcc_setup, mcc_scoring),
}
|