pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +1 -5
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/match_template.py +177 -226
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/postprocess.py +69 -47
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocess.py +10 -23
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +98 -28
- pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/METADATA +15 -15
- pytme-0.3.1.dist-info/RECORD +133 -0
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +1 -0
- pytme-0.3.1.dist-info/licenses/LICENSE +339 -0
- scripts/estimate_memory_usage.py +1 -5
- scripts/eval.py +93 -0
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +177 -226
- scripts/match_template_filters.py +1200 -0
- scripts/postprocess.py +69 -47
- scripts/preprocess.py +10 -23
- scripts/preprocessor_gui.py +98 -28
- scripts/pytme_runner.py +1223 -0
- scripts/refine_matches.py +156 -387
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_frequency_filters.py +19 -10
- tests/preprocessing/test_utils.py +18 -0
- tests/test_analyzer.py +122 -122
- tests/test_backends.py +4 -9
- tests/test_density.py +0 -1
- tests/test_matching_cli.py +30 -30
- tests/test_matching_data.py +5 -5
- tests/test_matching_utils.py +11 -61
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +1 -1
- tme/analyzer/_utils.py +5 -8
- tme/analyzer/aggregation.py +28 -9
- tme/analyzer/base.py +25 -36
- tme/analyzer/peaks.py +49 -122
- tme/analyzer/proxy.py +1 -0
- tme/backends/_jax_utils.py +31 -28
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +11 -54
- tme/backends/jax_backend.py +72 -48
- tme/backends/matching_backend.py +6 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +95 -90
- tme/backends/pytorch_backend.py +5 -26
- tme/density.py +7 -10
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/__init__.py +2 -2
- tme/filters/_utils.py +32 -7
- tme/filters/bandpass.py +225 -186
- tme/filters/ctf.py +138 -87
- tme/filters/reconstruction.py +38 -9
- tme/filters/wedge.py +98 -112
- tme/filters/whitening.py +1 -6
- tme/mask.py +341 -0
- tme/matching_data.py +20 -44
- tme/matching_exhaustive.py +46 -56
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +216 -412
- tme/matching_utils.py +82 -424
- tme/memory.py +1 -1
- tme/orientations.py +16 -8
- tme/parser.py +109 -29
- tme/preprocessor.py +2 -2
- tme/rotations.py +1 -1
- pytme-0.3b0.dist-info/RECORD +0 -122
- pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
tme/matching_scores.py
CHANGED
@@ -18,138 +18,9 @@ from .matching_utils import (
|
|
18
18
|
conditional_execute,
|
19
19
|
identity,
|
20
20
|
normalize_template,
|
21
|
-
_normalize_template_overflow_safe,
|
22
21
|
)
|
23
22
|
|
24
23
|
|
25
|
-
def _shape_match(shape1: Tuple[int], shape2: Tuple[int]) -> bool:
|
26
|
-
"""
|
27
|
-
Determine whether ``shape1`` is equal to ``shape2``.
|
28
|
-
|
29
|
-
Parameters
|
30
|
-
----------
|
31
|
-
shape1, shape2 : tuple of ints
|
32
|
-
Shapes to compare.
|
33
|
-
|
34
|
-
Returns
|
35
|
-
-------
|
36
|
-
Bool
|
37
|
-
``shape1`` is equal to ``shape2``.
|
38
|
-
"""
|
39
|
-
if len(shape1) != len(shape2):
|
40
|
-
return False
|
41
|
-
return shape1 == shape2
|
42
|
-
|
43
|
-
|
44
|
-
def _create_filter_func(
|
45
|
-
fwd_shape: Tuple[int],
|
46
|
-
inv_shape: Tuple[int],
|
47
|
-
arr_shape: Tuple[int],
|
48
|
-
arr_filter: BackendArray,
|
49
|
-
arr_ft_shape: Tuple[int],
|
50
|
-
inv_output_shape: Tuple[int],
|
51
|
-
real_dtype: type,
|
52
|
-
cmpl_dtype: type,
|
53
|
-
fwd_axes=None,
|
54
|
-
inv_axes=None,
|
55
|
-
rfftn: Callable = None,
|
56
|
-
irfftn: Callable = None,
|
57
|
-
) -> Callable:
|
58
|
-
"""
|
59
|
-
Configure template filtering function for Fourier transforms.
|
60
|
-
|
61
|
-
Conceptually we distinguish between three cases. The base case
|
62
|
-
is that both template and the corresponding filter have the same
|
63
|
-
shape. Padding is used when the template filter is larger than
|
64
|
-
the template, for instance to better resolve Fourier filters. Finally
|
65
|
-
this function also handles the case when a filter is supposed to be
|
66
|
-
broadcasted over the template batch dimension.
|
67
|
-
|
68
|
-
Parameters
|
69
|
-
----------
|
70
|
-
fwd_shape : tuple of ints
|
71
|
-
Input shape of rfftn.
|
72
|
-
inv_shape : tuple of ints
|
73
|
-
Input shape of irfftn.
|
74
|
-
arr_shape : tuple of ints
|
75
|
-
Shape of the array to be filtered.
|
76
|
-
arr_ft_shape : tuple of ints
|
77
|
-
Shape of the Fourier transform of the array.
|
78
|
-
arr_filter : BackendArray
|
79
|
-
Precomputed filter to apply in the frequency domain.
|
80
|
-
rfftn : Callable, optional
|
81
|
-
Foward Fourier transform.
|
82
|
-
irfftn : Callable, optional
|
83
|
-
Inverse Fourier transform.
|
84
|
-
|
85
|
-
Returns
|
86
|
-
-------
|
87
|
-
Callable
|
88
|
-
Filter function with parameters template, ft_temp and template_filter.
|
89
|
-
"""
|
90
|
-
if be.size(arr_filter) == 1:
|
91
|
-
return conditional_execute(identity, identity, False)
|
92
|
-
|
93
|
-
filter_shape = tuple(int(x) for x in arr_filter.shape)
|
94
|
-
try:
|
95
|
-
product_ft_shape = np.broadcast_shapes(arr_ft_shape, filter_shape)
|
96
|
-
except ValueError:
|
97
|
-
product_ft_shape, inv_output_shape = filter_shape, arr_shape
|
98
|
-
|
99
|
-
rfft_valid = _shape_match(arr_shape, fwd_shape)
|
100
|
-
rfft_valid = rfft_valid and _shape_match(product_ft_shape, inv_shape)
|
101
|
-
rfft_valid = rfft_valid and rfftn is not None and irfftn is not None
|
102
|
-
|
103
|
-
# FTTs were not or built for the wrong shape
|
104
|
-
if not rfft_valid:
|
105
|
-
_fwd_shape = arr_shape
|
106
|
-
if all(x > y for x, y in zip(arr_shape, product_ft_shape)):
|
107
|
-
_fwd_shape = fwd_shape
|
108
|
-
|
109
|
-
rfftn, irfftn = be.build_fft(
|
110
|
-
fwd_shape=_fwd_shape,
|
111
|
-
inv_shape=product_ft_shape,
|
112
|
-
real_dtype=real_dtype,
|
113
|
-
cmpl_dtype=cmpl_dtype,
|
114
|
-
inv_output_shape=inv_output_shape,
|
115
|
-
fwd_axes=fwd_axes,
|
116
|
-
inv_axes=inv_axes,
|
117
|
-
)
|
118
|
-
|
119
|
-
# Default case, all shapes are correctly matched
|
120
|
-
def _apply_filter(template, ft_temp, template_filter):
|
121
|
-
ft_temp = rfftn(template, ft_temp)
|
122
|
-
ft_temp = be.multiply(ft_temp, template_filter, out=ft_temp)
|
123
|
-
return irfftn(ft_temp, template)
|
124
|
-
|
125
|
-
if not _shape_match(arr_ft_shape, filter_shape):
|
126
|
-
real_subset = tuple(slice(0, x) for x in arr_shape)
|
127
|
-
_template = be.zeros(arr_shape, be._float_dtype)
|
128
|
-
_ft_temp = be.zeros(product_ft_shape, be._complex_dtype)
|
129
|
-
|
130
|
-
# Arr is padded, filter is not
|
131
|
-
def _apply_filter_subset(template, ft_temp, template_filter):
|
132
|
-
# TODO: Benchmark this
|
133
|
-
_template[:] = template[real_subset]
|
134
|
-
template[real_subset] = _apply_filter(_template, _ft_temp, template_filter)
|
135
|
-
return template
|
136
|
-
|
137
|
-
# Filter application requires a broadcasting operation
|
138
|
-
def _apply_filter_broadcast(template, ft_temp, template_filter):
|
139
|
-
_ft_prod = rfftn(template, _ft_temp2)
|
140
|
-
_ft_res = be.multiply(_ft_prod, template_filter, out=_ft_temp)
|
141
|
-
return irfftn(_ft_res, _template)
|
142
|
-
|
143
|
-
if any(x > y and y == 1 for x, y in zip(filter_shape, arr_ft_shape)):
|
144
|
-
_template = be.zeros(inv_output_shape, be._float_dtype)
|
145
|
-
_ft_temp2 = be.zeros((1, *product_ft_shape[1:]), be._complex_dtype)
|
146
|
-
return _apply_filter_broadcast
|
147
|
-
|
148
|
-
return _apply_filter_subset
|
149
|
-
|
150
|
-
return _apply_filter
|
151
|
-
|
152
|
-
|
153
24
|
def cc_setup(
|
154
25
|
matching_data: type,
|
155
26
|
fast_shape: Tuple[int],
|
@@ -176,8 +47,6 @@ def cc_setup(
|
|
176
47
|
axes = matching_data._batch_axis(matching_data._batch_mask)
|
177
48
|
|
178
49
|
ret = {
|
179
|
-
"fast_shape": fast_shape,
|
180
|
-
"fast_ft_shape": fast_ft_shape,
|
181
50
|
"template": be.to_sharedarr(matching_data.template, shm_handler),
|
182
51
|
"ft_target": be.to_sharedarr(be.rfftn(target_pad, axes=axes), shm_handler),
|
183
52
|
"inv_denominator": be.to_sharedarr(
|
@@ -219,8 +88,8 @@ def lcc_setup(matching_data, **kwargs) -> Dict:
|
|
219
88
|
for subset in subsets:
|
220
89
|
template[subset] = laplace(template[subset], mode="wrap")
|
221
90
|
|
222
|
-
matching_data._target = target
|
223
|
-
matching_data._template = template
|
91
|
+
matching_data._target = be.to_backend_array(target)
|
92
|
+
matching_data._template = be.to_backend_array(template)
|
224
93
|
|
225
94
|
return cc_setup(matching_data=matching_data, **kwargs)
|
226
95
|
|
@@ -316,8 +185,6 @@ def corr_setup(
|
|
316
185
|
denominator = be.multiply(denominator, mask, out=denominator)
|
317
186
|
|
318
187
|
ret = {
|
319
|
-
"fast_shape": fast_shape,
|
320
|
-
"fast_ft_shape": fast_ft_shape,
|
321
188
|
"template": be.to_sharedarr(template, shm_handler),
|
322
189
|
"ft_target": be.to_sharedarr(ft_target, shm_handler),
|
323
190
|
"inv_denominator": be.to_sharedarr(denominator, shm_handler),
|
@@ -376,8 +243,6 @@ def flc_setup(
|
|
376
243
|
ft_target2 = be.rfftn(target_pad, axes=data_axes)
|
377
244
|
|
378
245
|
ret = {
|
379
|
-
"fast_shape": fast_shape,
|
380
|
-
"fast_ft_shape": fast_ft_shape,
|
381
246
|
"template": be.to_sharedarr(matching_data.template, shm_handler),
|
382
247
|
"template_mask": be.to_sharedarr(matching_data.template_mask, shm_handler),
|
383
248
|
"ft_target": be.to_sharedarr(ft_target, shm_handler),
|
@@ -438,8 +303,6 @@ def flcSphericalMask_setup(
|
|
438
303
|
|
439
304
|
temp2 = be.norm_scores(1, temp2, temp, n_obs, be.eps(be._float_dtype), temp2)
|
440
305
|
ret = {
|
441
|
-
"fast_shape": fast_shape,
|
442
|
-
"fast_ft_shape": fast_ft_shape,
|
443
306
|
"template": be.to_sharedarr(matching_data.template, shm_handler),
|
444
307
|
"template_mask": be.to_sharedarr(template_mask, shm_handler),
|
445
308
|
"ft_target": be.to_sharedarr(ft_target, shm_handler),
|
@@ -461,21 +324,14 @@ def mcc_setup(
|
|
461
324
|
Setup function for :py:meth:`mcc_scoring`.
|
462
325
|
"""
|
463
326
|
target, target_mask = matching_data.target, matching_data.target_mask
|
464
|
-
target = be.multiply(target, target_mask
|
327
|
+
target = be.multiply(target, target_mask, out=target)
|
465
328
|
|
466
|
-
target = be.topleft_pad(
|
467
|
-
target,
|
468
|
-
matching_data._batch_shape(fast_shape, matching_data._template_batch),
|
469
|
-
)
|
470
|
-
target_mask = be.topleft_pad(
|
471
|
-
target_mask,
|
472
|
-
matching_data._batch_shape(fast_shape, matching_data._template_batch),
|
473
|
-
)
|
474
329
|
ax = matching_data._batch_axis(matching_data._batch_mask)
|
330
|
+
shape = matching_data._batch_shape(fast_shape, matching_data._template_batch)
|
331
|
+
target = be.topleft_pad(target, shape)
|
332
|
+
target_mask = be.topleft_pad(target_mask, shape)
|
475
333
|
|
476
334
|
ret = {
|
477
|
-
"fast_shape": fast_shape,
|
478
|
-
"fast_ft_shape": fast_ft_shape,
|
479
335
|
"template": be.to_sharedarr(matching_data.template, shm_handler),
|
480
336
|
"template_mask": be.to_sharedarr(matching_data.template_mask, shm_handler),
|
481
337
|
"ft_target": be.to_sharedarr(be.rfftn(target, axes=ax), shm_handler),
|
@@ -541,8 +397,7 @@ def corr_scoring(
|
|
541
397
|
|
542
398
|
Returns
|
543
399
|
-------
|
544
|
-
|
545
|
-
``callback`` if provided otherwise None.
|
400
|
+
CallbackClass
|
546
401
|
"""
|
547
402
|
template = be.from_sharedarr(template)
|
548
403
|
ft_target = be.from_sharedarr(ft_target)
|
@@ -550,71 +405,49 @@ def corr_scoring(
|
|
550
405
|
numerator = be.from_sharedarr(numerator)
|
551
406
|
template_filter = be.from_sharedarr(template_filter)
|
552
407
|
|
553
|
-
|
408
|
+
n_obs = None
|
554
409
|
if template_mask is not None:
|
555
410
|
template_mask = be.from_sharedarr(template_mask)
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
callback_func = conditional_execute(callback, callback is not None)
|
562
|
-
norm_template = conditional_execute(norm_func, norm_template)
|
563
|
-
norm_numerator = conditional_execute(
|
564
|
-
be.subtract, identity, _shape_match(numerator.shape, fast_shape)
|
565
|
-
)
|
566
|
-
norm_denominator = conditional_execute(
|
567
|
-
be.multiply, identity, _shape_match(inv_denominator.shape, fast_shape)
|
568
|
-
)
|
411
|
+
n_obs = be.sum(template_mask) if template_mask is not None else None
|
412
|
+
|
413
|
+
norm_template = conditional_execute(normalize_template, n_obs is not None)
|
414
|
+
norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
|
415
|
+
norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1))
|
569
416
|
|
570
417
|
arr = be.zeros(fast_shape, be._float_dtype)
|
571
418
|
ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
|
572
|
-
|
573
|
-
_fftargs = {
|
574
|
-
"real_dtype": be._float_dtype,
|
575
|
-
"cmpl_dtype": be._complex_dtype,
|
576
|
-
"inv_output_shape": fast_shape,
|
577
|
-
"fwd_axes": None,
|
578
|
-
"inv_axes": None,
|
579
|
-
"inv_shape": fast_ft_shape,
|
580
|
-
"temp_fwd": arr,
|
581
|
-
}
|
582
|
-
|
583
|
-
_fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
|
584
|
-
rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
|
585
|
-
_ = _fftargs.pop("temp_fwd", None)
|
419
|
+
template_rot = be.zeros(template.shape, be._float_dtype)
|
586
420
|
|
587
421
|
template_filter_func = _create_filter_func(
|
588
422
|
arr_shape=template.shape,
|
589
|
-
|
590
|
-
arr_filter=template_filter,
|
591
|
-
rfftn=rfftn,
|
592
|
-
irfftn=irfftn,
|
593
|
-
**_fftargs,
|
423
|
+
filter_shape=template_filter.shape,
|
594
424
|
)
|
595
425
|
|
426
|
+
center = be.divide(be.to_backend_array(template.shape) - 1, 2)
|
596
427
|
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
597
428
|
for index in range(rotations.shape[0]):
|
598
429
|
rotation = rotations[index]
|
599
|
-
|
600
|
-
|
430
|
+
matrix = be._rigid_transform_matrix(rotation_matrix=rotation, center=center)
|
431
|
+
_ = be.rigid_transform(
|
601
432
|
arr=template,
|
602
|
-
rotation_matrix=
|
603
|
-
out=
|
604
|
-
use_geometric_center=True,
|
433
|
+
rotation_matrix=matrix,
|
434
|
+
out=template_rot,
|
605
435
|
order=interpolation_order,
|
606
|
-
cache=
|
436
|
+
cache=True,
|
607
437
|
)
|
608
|
-
arr = template_filter_func(arr, ft_temp, template_filter)
|
609
|
-
norm_template(arr[unpadded_slice], template_mask, mask_sum)
|
610
438
|
|
611
|
-
|
612
|
-
|
613
|
-
|
439
|
+
template_rot = template_filter_func(template_rot, ft_temp, template_filter)
|
440
|
+
norm_template(template_rot, template_mask, n_obs)
|
441
|
+
|
442
|
+
arr = be.fill(arr, 0)
|
443
|
+
arr[unpadded_slice] = template_rot
|
614
444
|
|
615
|
-
|
616
|
-
arr =
|
617
|
-
|
445
|
+
ft_temp = be.rfftn(arr, s=fast_shape, out=ft_temp)
|
446
|
+
arr = _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape)
|
447
|
+
|
448
|
+
arr = norm_sub(arr, numerator, out=arr)
|
449
|
+
arr = norm_mul(arr, inv_denominator, out=arr)
|
450
|
+
callback(arr, rotation_matrix=rotation)
|
618
451
|
|
619
452
|
return callback
|
620
453
|
|
@@ -675,6 +508,10 @@ def flc_scoring(
|
|
675
508
|
interpolation_order : int
|
676
509
|
Spline order for template rotations.
|
677
510
|
|
511
|
+
Returns
|
512
|
+
-------
|
513
|
+
CallbackClass
|
514
|
+
|
678
515
|
References
|
679
516
|
----------
|
680
517
|
.. [1] Hrabe T. et al, J. Struct. Biol. 178, 177 (2012).
|
@@ -691,63 +528,46 @@ def flc_scoring(
|
|
691
528
|
temp2 = be.zeros(fast_shape, float_dtype)
|
692
529
|
ft_temp = be.zeros(fast_ft_shape, complex_dtype)
|
693
530
|
ft_denom = be.zeros(fast_ft_shape, complex_dtype)
|
531
|
+
template_rot = be.zeros(template.shape, be._float_dtype)
|
532
|
+
template_mask_rot = be.zeros(template.shape, be._float_dtype)
|
694
533
|
|
695
|
-
|
696
|
-
"real_dtype": be._float_dtype,
|
697
|
-
"cmpl_dtype": be._complex_dtype,
|
698
|
-
"inv_output_shape": fast_shape,
|
699
|
-
"fwd_axes": None,
|
700
|
-
"inv_axes": None,
|
701
|
-
"inv_shape": fast_ft_shape,
|
702
|
-
"temp_fwd": arr,
|
703
|
-
}
|
704
|
-
|
705
|
-
_fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
|
706
|
-
rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
|
707
|
-
_ = _fftargs.pop("temp_fwd", None)
|
708
|
-
|
709
|
-
template_filter_func = _create_filter_func(
|
710
|
-
arr_shape=template.shape,
|
711
|
-
arr_ft_shape=fast_ft_shape,
|
712
|
-
arr_filter=template_filter,
|
713
|
-
rfftn=rfftn,
|
714
|
-
irfftn=irfftn,
|
715
|
-
**_fftargs,
|
716
|
-
)
|
534
|
+
tmpl_filter_func = _create_filter_func(template.shape, template_filter.shape)
|
717
535
|
|
718
536
|
eps = be.eps(float_dtype)
|
719
|
-
|
537
|
+
center = be.divide(be.to_backend_array(template.shape) - 1, 2)
|
538
|
+
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
720
539
|
for index in range(rotations.shape[0]):
|
721
540
|
rotation = rotations[index]
|
722
|
-
|
723
|
-
|
724
|
-
arr, temp = be.rigid_transform(
|
541
|
+
matrix = be._rigid_transform_matrix(rotation_matrix=rotation, center=center)
|
542
|
+
_ = be.rigid_transform(
|
725
543
|
arr=template,
|
726
544
|
arr_mask=template_mask,
|
727
|
-
rotation_matrix=
|
728
|
-
out=
|
729
|
-
out_mask=
|
545
|
+
rotation_matrix=matrix,
|
546
|
+
out=template_rot,
|
547
|
+
out_mask=template_mask_rot,
|
730
548
|
use_geometric_center=True,
|
731
549
|
order=interpolation_order,
|
732
|
-
cache=
|
550
|
+
cache=True,
|
733
551
|
)
|
734
552
|
|
735
|
-
n_obs = be.sum(
|
736
|
-
|
737
|
-
|
553
|
+
n_obs = be.sum(template_mask_rot)
|
554
|
+
template_rot = tmpl_filter_func(template_rot, ft_temp, template_filter)
|
555
|
+
template_rot = normalize_template(template_rot, template_mask_rot, n_obs)
|
556
|
+
|
557
|
+
arr = be.fill(arr, 0)
|
558
|
+
temp = be.fill(temp, 0)
|
559
|
+
arr[unpadded_slice] = template_rot
|
560
|
+
temp[unpadded_slice] = template_mask_rot
|
738
561
|
|
739
|
-
ft_temp = rfftn(temp, ft_temp)
|
740
|
-
|
741
|
-
|
742
|
-
ft_denom = be.multiply(ft_target2, ft_temp, out=ft_denom)
|
743
|
-
temp2 = irfftn(ft_denom, temp2)
|
562
|
+
ft_temp = be.rfftn(temp, out=ft_temp, s=fast_shape)
|
563
|
+
temp = _correlate_fts(ft_target, ft_temp, ft_denom, temp, fast_shape)
|
564
|
+
temp2 = _correlate_fts(ft_target2, ft_temp, ft_denom, temp, fast_shape)
|
744
565
|
|
745
|
-
ft_temp = rfftn(arr, ft_temp)
|
746
|
-
|
747
|
-
arr = irfftn(ft_temp, arr)
|
566
|
+
ft_temp = be.rfftn(arr, out=ft_temp, s=fast_shape)
|
567
|
+
arr = _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape)
|
748
568
|
|
749
569
|
arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
|
750
|
-
|
570
|
+
callback(arr, rotation_matrix=rotation)
|
751
571
|
|
752
572
|
return callback
|
753
573
|
|
@@ -814,6 +634,10 @@ def mcc_scoring(
|
|
814
634
|
overlap_ratio : float, optional
|
815
635
|
Required fractional mask overlap, 0.3 by default.
|
816
636
|
|
637
|
+
Returns
|
638
|
+
-------
|
639
|
+
CallbackClass
|
640
|
+
|
817
641
|
References
|
818
642
|
----------
|
819
643
|
.. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
|
@@ -839,30 +663,12 @@ def mcc_scoring(
|
|
839
663
|
temp3 = be.zeros(fast_shape, float_dtype)
|
840
664
|
temp_ft = be.zeros(fast_ft_shape, complex_dtype)
|
841
665
|
|
842
|
-
_fftargs = {
|
843
|
-
"real_dtype": be._float_dtype,
|
844
|
-
"cmpl_dtype": be._complex_dtype,
|
845
|
-
"inv_output_shape": fast_shape,
|
846
|
-
"fwd_axes": None,
|
847
|
-
"inv_axes": None,
|
848
|
-
"inv_shape": fast_ft_shape,
|
849
|
-
"temp_fwd": temp,
|
850
|
-
}
|
851
|
-
|
852
|
-
_fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
|
853
|
-
rfftn, irfftn = be.build_fft(temp_inv=temp_ft, **_fftargs)
|
854
|
-
_ = _fftargs.pop("temp_fwd", None)
|
855
|
-
|
856
666
|
template_filter_func = _create_filter_func(
|
857
667
|
arr_shape=template.shape,
|
858
|
-
|
859
|
-
|
860
|
-
rfftn=rfftn,
|
861
|
-
irfftn=irfftn,
|
862
|
-
**_fftargs,
|
668
|
+
filter_shape=template_filter.shape,
|
669
|
+
arr_padded=True,
|
863
670
|
)
|
864
671
|
|
865
|
-
callback_func = conditional_execute(callback, callback is not None)
|
866
672
|
for index in range(rotations.shape[0]):
|
867
673
|
rotation = rotations[index]
|
868
674
|
template_rot = be.fill(template_rot, 0)
|
@@ -875,23 +681,25 @@ def mcc_scoring(
|
|
875
681
|
out_mask=temp,
|
876
682
|
use_geometric_center=True,
|
877
683
|
order=interpolation_order,
|
878
|
-
cache=
|
684
|
+
cache=True,
|
879
685
|
)
|
880
686
|
|
881
687
|
template_filter_func(template_rot, temp_ft, template_filter)
|
882
688
|
normalize_template(template_rot, temp, be.sum(temp))
|
883
689
|
|
884
|
-
temp_ft = rfftn(template_rot, temp_ft)
|
885
|
-
temp2 = irfftn(target_mask_ft * temp_ft, temp2)
|
886
|
-
numerator = irfftn(target_ft * temp_ft, numerator)
|
690
|
+
temp_ft = be.rfftn(template_rot, out=temp_ft, s=fast_shape)
|
691
|
+
temp2 = be.irfftn(target_mask_ft * temp_ft, out=temp2, s=fast_shape)
|
692
|
+
numerator = be.irfftn(target_ft * temp_ft, out=numerator, s=fast_shape)
|
887
693
|
|
888
694
|
# temp template_mask_rot | temp_ft template_mask_rot_ft
|
889
695
|
# Calculate overlap of masks at every point in the convolution.
|
890
696
|
# Locations with high overlap should not be taken into account.
|
891
|
-
temp_ft = rfftn(temp, temp_ft)
|
892
|
-
mask_overlap = irfftn(
|
697
|
+
temp_ft = be.rfftn(temp, out=temp_ft, s=fast_shape)
|
698
|
+
mask_overlap = be.irfftn(
|
699
|
+
temp_ft * target_mask_ft, out=mask_overlap, s=fast_shape
|
700
|
+
)
|
893
701
|
be.maximum(mask_overlap, eps, out=mask_overlap)
|
894
|
-
temp = irfftn(temp_ft * target_ft, temp)
|
702
|
+
temp = be.irfftn(temp_ft * target_ft, out=temp, s=fast_shape)
|
895
703
|
|
896
704
|
be.subtract(
|
897
705
|
numerator,
|
@@ -901,21 +709,21 @@ def mcc_scoring(
|
|
901
709
|
|
902
710
|
# temp_3 = fixed_denom
|
903
711
|
be.multiply(temp_ft, target_ft2, out=temp_ft)
|
904
|
-
temp3 = irfftn(temp_ft, temp3)
|
712
|
+
temp3 = be.irfftn(temp_ft, out=temp3, s=fast_shape)
|
905
713
|
be.subtract(temp3, be.divide(be.square(temp), mask_overlap), out=temp3)
|
906
714
|
be.maximum(temp3, 0.0, out=temp3)
|
907
715
|
|
908
716
|
# temp = moving_denom
|
909
|
-
temp_ft = rfftn(be.square(template_rot), temp_ft)
|
717
|
+
temp_ft = be.rfftn(be.square(template_rot), out=temp_ft, s=fast_shape)
|
910
718
|
be.multiply(target_mask_ft, temp_ft, out=temp_ft)
|
911
|
-
temp = irfftn(temp_ft, temp)
|
719
|
+
temp = be.irfftn(temp_ft, out=temp, s=fast_shape)
|
912
720
|
|
913
721
|
be.subtract(temp, be.divide(be.square(temp2), mask_overlap), out=temp)
|
914
722
|
be.maximum(temp, 0.0, out=temp)
|
915
723
|
|
916
724
|
# temp_2 = denom
|
917
725
|
be.multiply(temp3, temp, out=temp)
|
918
|
-
be.sqrt(temp, temp2)
|
726
|
+
be.sqrt(temp, out=temp2)
|
919
727
|
|
920
728
|
# Pixels where `denom` is very small will introduce large
|
921
729
|
# numbers after division. To get around this problem,
|
@@ -931,29 +739,11 @@ def mcc_scoring(
|
|
931
739
|
mask_overlap, axis=axes, keepdims=True
|
932
740
|
)
|
933
741
|
temp[mask_overlap < number_px_threshold] = 0.0
|
934
|
-
|
742
|
+
callback(temp, rotation_matrix=rotation)
|
935
743
|
|
936
744
|
return callback
|
937
745
|
|
938
746
|
|
939
|
-
def _format_slice(shape, squeeze_axis):
|
940
|
-
ret = tuple(
|
941
|
-
slice(None) if i not in squeeze_axis else 0 for i, _ in enumerate(shape)
|
942
|
-
)
|
943
|
-
return ret
|
944
|
-
|
945
|
-
|
946
|
-
def _get_batch_dim(target, template):
|
947
|
-
target_batch, template_batch = [], []
|
948
|
-
for i in range(len(target.shape)):
|
949
|
-
if target.shape[i] == 1 and template.shape[i] != 1:
|
950
|
-
template_batch.append(i)
|
951
|
-
if target.shape[i] != 1 and template.shape[i] == 1:
|
952
|
-
target_batch.append(i)
|
953
|
-
|
954
|
-
return target_batch, template_batch
|
955
|
-
|
956
|
-
|
957
747
|
def flc_scoring2(
|
958
748
|
template: shm_type,
|
959
749
|
template_mask: shm_type,
|
@@ -966,25 +756,22 @@ def flc_scoring2(
|
|
966
756
|
callback: CallbackClass,
|
967
757
|
interpolation_order: int,
|
968
758
|
) -> CallbackClass:
|
969
|
-
callback_func = conditional_execute(callback, callback is not None)
|
970
|
-
|
971
|
-
# Retrieve objects from shared memory
|
972
759
|
template = be.from_sharedarr(template)
|
973
760
|
template_mask = be.from_sharedarr(template_mask)
|
974
761
|
ft_target = be.from_sharedarr(ft_target)
|
975
762
|
ft_target2 = be.from_sharedarr(ft_target2)
|
976
763
|
template_filter = be.from_sharedarr(template_filter)
|
977
764
|
|
978
|
-
|
979
|
-
|
980
|
-
|
981
|
-
sqz_slice = tuple(slice(0, 1) if
|
765
|
+
tar_batch, tmpl_batch = _get_batch_dim(ft_target, template)
|
766
|
+
|
767
|
+
nd = len(fast_shape)
|
768
|
+
sqz_slice = tuple(slice(0, 1) if i in tar_batch else slice(None) for i in range(nd))
|
769
|
+
tmpl_subset = tuple(0 if i in tar_batch else slice(None) for i in range(nd))
|
982
770
|
|
983
|
-
|
984
|
-
if len(
|
985
|
-
|
986
|
-
|
987
|
-
data_shape = tuple(fast_shape[i] for i in data_axes)
|
771
|
+
axes, shape, batched = None, fast_shape, len(tmpl_batch) > 0
|
772
|
+
if len(tar_batch) or len(tmpl_batch):
|
773
|
+
axes = tuple(i for i in range(nd) if i not in (*tar_batch, *tmpl_batch))
|
774
|
+
shape = tuple(fast_shape[i] for i in axes)
|
988
775
|
|
989
776
|
arr = be.zeros(fast_shape, be._float_dtype)
|
990
777
|
temp = be.zeros(fast_shape, be._float_dtype)
|
@@ -992,34 +779,11 @@ def flc_scoring2(
|
|
992
779
|
ft_denom = be.zeros(fast_ft_shape, be._complex_dtype)
|
993
780
|
|
994
781
|
tmp_sqz, arr_sqz, ft_temp = temp[sqz_slice], arr[sqz_slice], ft_denom[sqz_slice]
|
995
|
-
if be.size(template_filter) != 1:
|
996
|
-
ret_shape = np.broadcast_shapes(
|
997
|
-
sqz_cmpl, tuple(int(x) for x in template_filter.shape)
|
998
|
-
)
|
999
|
-
ft_temp = be.zeros(ret_shape, be._complex_dtype)
|
1000
|
-
|
1001
|
-
_fftargs = {
|
1002
|
-
"real_dtype": be._float_dtype,
|
1003
|
-
"cmpl_dtype": be._complex_dtype,
|
1004
|
-
"inv_output_shape": fast_shape,
|
1005
|
-
"fwd_axes": data_axes,
|
1006
|
-
"inv_axes": data_axes,
|
1007
|
-
"inv_shape": fast_ft_shape,
|
1008
|
-
"temp_fwd": arr_sqz if _shape_match(ft_temp.shape, sqz_cmpl) else arr,
|
1009
|
-
}
|
1010
|
-
|
1011
|
-
# build_fft ignores fwd_shape if temp_fwd is given and serves only for bookkeeping
|
1012
|
-
_fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
|
1013
|
-
rfftn, irfftn = be.build_fft(temp_inv=ft_denom, **_fftargs)
|
1014
|
-
_ = _fftargs.pop("temp_fwd", None)
|
1015
782
|
|
1016
783
|
template_filter_func = _create_filter_func(
|
1017
784
|
arr_shape=template.shape,
|
1018
|
-
|
1019
|
-
|
1020
|
-
rfftn=rfftn,
|
1021
|
-
irfftn=irfftn,
|
1022
|
-
**_fftargs,
|
785
|
+
filter_shape=template_filter.shape,
|
786
|
+
arr_padded=True,
|
1023
787
|
)
|
1024
788
|
|
1025
789
|
eps = be.eps(be._float_dtype)
|
@@ -1027,32 +791,32 @@ def flc_scoring2(
|
|
1027
791
|
rotation = rotations[index]
|
1028
792
|
be.fill(arr, 0)
|
1029
793
|
be.fill(temp, 0)
|
1030
|
-
|
1031
|
-
|
1032
|
-
|
794
|
+
|
795
|
+
_, _ = be.rigid_transform(
|
796
|
+
arr=template[tmpl_subset],
|
797
|
+
arr_mask=template_mask[tmpl_subset],
|
1033
798
|
rotation_matrix=rotation,
|
1034
|
-
out=arr_sqz,
|
1035
|
-
out_mask=tmp_sqz,
|
799
|
+
out=arr_sqz[tmpl_subset],
|
800
|
+
out_mask=tmp_sqz[tmpl_subset],
|
1036
801
|
use_geometric_center=True,
|
1037
802
|
order=interpolation_order,
|
1038
803
|
cache=False,
|
804
|
+
batched=batched,
|
1039
805
|
)
|
1040
|
-
|
806
|
+
|
807
|
+
n_obs = be.sum(tmp_sqz, axis=axes, keepdims=True)
|
1041
808
|
arr_norm = template_filter_func(arr_sqz, ft_temp, template_filter)
|
1042
|
-
arr_norm = normalize_template(arr_norm, tmp_sqz, n_obs, axis=
|
809
|
+
arr_norm = normalize_template(arr_norm, tmp_sqz, n_obs, axis=axes)
|
1043
810
|
|
1044
|
-
ft_temp = be.rfftn(tmp_sqz, ft_temp, axes=
|
1045
|
-
|
1046
|
-
|
1047
|
-
ft_denom = be.multiply(ft_target2, ft_temp, out=ft_denom)
|
1048
|
-
temp2 = be.irfftn(ft_denom, temp2, axes=data_axes, s=data_shape)
|
811
|
+
ft_temp = be.rfftn(tmp_sqz, out=ft_temp, axes=axes, s=shape)
|
812
|
+
temp = _correlate_fts(ft_target, ft_temp, ft_denom, temp, shape, axes)
|
813
|
+
temp2 = _correlate_fts(ft_target2, ft_temp, ft_denom, temp2, shape, axes)
|
1049
814
|
|
1050
|
-
ft_temp = rfftn(arr_norm,
|
1051
|
-
|
1052
|
-
arr = irfftn(ft_denom, arr)
|
815
|
+
ft_temp = be.rfftn(arr_norm, out=ft_temp, axes=axes, s=shape)
|
816
|
+
arr = _correlate_fts(ft_target, ft_temp, ft_denom, arr, shape, axes)
|
1053
817
|
|
1054
|
-
be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
|
1055
|
-
|
818
|
+
arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
|
819
|
+
callback(arr, rotation_matrix=rotation)
|
1056
820
|
|
1057
821
|
return callback
|
1058
822
|
|
@@ -1077,100 +841,140 @@ def corr_scoring2(
|
|
1077
841
|
numerator = be.from_sharedarr(numerator)
|
1078
842
|
template_filter = be.from_sharedarr(template_filter)
|
1079
843
|
|
1080
|
-
|
1081
|
-
|
1082
|
-
|
1083
|
-
sqz_slice = tuple(slice(0, 1) if
|
1084
|
-
|
1085
|
-
|
1086
|
-
|
1087
|
-
|
1088
|
-
|
1089
|
-
|
1090
|
-
|
1091
|
-
|
844
|
+
tar_batch, tmpl_batch = _get_batch_dim(ft_target, template)
|
845
|
+
|
846
|
+
nd = len(fast_shape)
|
847
|
+
sqz_slice = tuple(slice(0, 1) if i in tar_batch else slice(None) for i in range(nd))
|
848
|
+
tmpl_subset = tuple(0 if i in tar_batch else slice(None) for i in range(nd))
|
849
|
+
|
850
|
+
axes, shape, batched = None, fast_shape, len(tmpl_batch) > 0
|
851
|
+
if len(tar_batch) or len(tmpl_batch):
|
852
|
+
axes = tuple(i for i in range(nd) if i not in (*tar_batch, *tmpl_batch))
|
853
|
+
shape = tuple(fast_shape[i] for i in axes)
|
854
|
+
|
855
|
+
unpadded_slice = tuple(
|
856
|
+
slice(None) if i in (*tar_batch, *tmpl_batch) else slice(0, x)
|
857
|
+
for i, x in enumerate(template.shape)
|
858
|
+
)
|
1092
859
|
|
1093
860
|
arr = be.zeros(fast_shape, be._float_dtype)
|
1094
861
|
ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
|
1095
862
|
arr_sqz, ft_sqz = arr[sqz_slice], ft_temp[sqz_slice]
|
1096
863
|
|
1097
|
-
|
1098
|
-
# The filter could be w.r.t the unpadded template
|
1099
|
-
ret_shape = tuple(
|
1100
|
-
int(x * y) if x == 1 or y == 1 else y
|
1101
|
-
for x, y in zip(sqz_cmpl, template_filter.shape)
|
1102
|
-
)
|
1103
|
-
ft_sqz = be.zeros(ret_shape, be._complex_dtype)
|
1104
|
-
|
1105
|
-
norm_func, norm_template, mask_sum = normalize_template, False, 1
|
864
|
+
n_obs = None
|
1106
865
|
if template_mask is not None:
|
1107
866
|
template_mask = be.from_sharedarr(template_mask)
|
1108
|
-
|
1109
|
-
be.astype(template_mask, be._overflow_safe_dtype),
|
1110
|
-
axis=data_axes,
|
1111
|
-
keepdims=True,
|
1112
|
-
)
|
1113
|
-
if be.datatype_bytes(template_mask.dtype) == 2:
|
1114
|
-
norm_func = _normalize_template_overflow_safe
|
867
|
+
n_obs = be.sum(template_mask, axis=axes, keepdims=True)
|
1115
868
|
|
1116
|
-
|
1117
|
-
|
1118
|
-
|
1119
|
-
be.subtract, identity, _shape_match(numerator.shape, fast_shape)
|
1120
|
-
)
|
1121
|
-
norm_denominator = conditional_execute(
|
1122
|
-
be.multiply, identity, _shape_match(inv_denominator.shape, fast_shape)
|
1123
|
-
)
|
1124
|
-
|
1125
|
-
_fftargs = {
|
1126
|
-
"real_dtype": be._float_dtype,
|
1127
|
-
"cmpl_dtype": be._complex_dtype,
|
1128
|
-
"fwd_axes": data_axes,
|
1129
|
-
"inv_axes": data_axes,
|
1130
|
-
"inv_shape": fast_ft_shape,
|
1131
|
-
"inv_output_shape": fast_shape,
|
1132
|
-
"temp_fwd": arr_sqz if _shape_match(ft_sqz.shape, sqz_cmpl) else arr,
|
1133
|
-
}
|
1134
|
-
|
1135
|
-
# build_fft ignores fwd_shape if temp_fwd is given and serves only for bookkeeping
|
1136
|
-
_fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
|
1137
|
-
rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
|
1138
|
-
_ = _fftargs.pop("temp_fwd", None)
|
869
|
+
norm_template = conditional_execute(normalize_template, n_obs is not None)
|
870
|
+
norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
|
871
|
+
norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1,))
|
1139
872
|
|
1140
873
|
template_filter_func = _create_filter_func(
|
1141
874
|
arr_shape=template.shape,
|
1142
|
-
|
1143
|
-
|
1144
|
-
rfftn=rfftn,
|
1145
|
-
irfftn=irfftn,
|
1146
|
-
**_fftargs,
|
875
|
+
filter_shape=template_filter.shape,
|
876
|
+
arr_padded=True,
|
1147
877
|
)
|
1148
878
|
|
1149
879
|
for index in range(rotations.shape[0]):
|
1150
880
|
be.fill(arr, 0)
|
1151
881
|
rotation = rotations[index]
|
1152
|
-
|
1153
|
-
arr=template,
|
882
|
+
_, _ = be.rigid_transform(
|
883
|
+
arr=template[tmpl_subset],
|
1154
884
|
rotation_matrix=rotation,
|
1155
|
-
out=arr_sqz,
|
885
|
+
out=arr_sqz[tmpl_subset],
|
1156
886
|
use_geometric_center=True,
|
1157
887
|
order=interpolation_order,
|
1158
888
|
cache=False,
|
889
|
+
batched=batched,
|
1159
890
|
)
|
1160
891
|
arr_norm = template_filter_func(arr_sqz, ft_sqz, template_filter)
|
1161
|
-
norm_template(arr_norm[unpadded_slice], template_mask,
|
892
|
+
norm_template(arr_norm[unpadded_slice], template_mask, n_obs, axis=axes)
|
1162
893
|
|
1163
|
-
ft_sqz = rfftn(arr_norm, ft_sqz)
|
1164
|
-
|
1165
|
-
arr = irfftn(ft_temp, arr)
|
894
|
+
ft_sqz = be.rfftn(arr_norm, out=ft_sqz, axes=axes, s=shape)
|
895
|
+
arr = _correlate_fts(ft_target, ft_sqz, ft_temp, arr, shape, axes)
|
1166
896
|
|
1167
|
-
arr =
|
1168
|
-
arr =
|
1169
|
-
|
897
|
+
arr = norm_sub(arr, numerator, out=arr)
|
898
|
+
arr = norm_mul(arr, inv_denominator, out=arr)
|
899
|
+
callback(arr, rotation_matrix=rotation)
|
1170
900
|
|
1171
901
|
return callback
|
1172
902
|
|
1173
903
|
|
904
|
+
def _get_batch_dim(target, template):
|
905
|
+
target_batch, template_batch = [], []
|
906
|
+
for i in range(len(target.shape)):
|
907
|
+
if target.shape[i] == 1 and template.shape[i] != 1:
|
908
|
+
template_batch.append(i)
|
909
|
+
if target.shape[i] != 1 and template.shape[i] == 1:
|
910
|
+
target_batch.append(i)
|
911
|
+
|
912
|
+
return target_batch, template_batch
|
913
|
+
|
914
|
+
|
915
|
+
def _correlate_fts(ft_tar, ft_tmpl, ft_buffer, real_buffer, fast_shape, axes=None):
|
916
|
+
ft_buffer = be.multiply(ft_tar, ft_tmpl, out=ft_buffer)
|
917
|
+
return be.irfftn(ft_buffer, out=real_buffer, s=fast_shape, axes=axes)
|
918
|
+
|
919
|
+
|
920
|
+
def _create_filter_func(
|
921
|
+
arr_shape: Tuple[int],
|
922
|
+
filter_shape: BackendArray,
|
923
|
+
arr_padded: bool = False,
|
924
|
+
axes=None,
|
925
|
+
) -> Callable:
|
926
|
+
"""
|
927
|
+
Configure template filtering function for Fourier transforms.
|
928
|
+
|
929
|
+
Conceptually we distinguish between three cases. The base case
|
930
|
+
is that both template and the corresponding filter have the same
|
931
|
+
shape. Padding is used when the template filter is larger than
|
932
|
+
the template, for instance to better resolve Fourier filters. Finally
|
933
|
+
this function also handles the case when a filter is supposed to be
|
934
|
+
broadcasted over the template batch dimension.
|
935
|
+
|
936
|
+
Parameters
|
937
|
+
----------
|
938
|
+
arr_shape : tuple of ints
|
939
|
+
Shape of the array to be filtered.
|
940
|
+
filter_shape : BackendArray
|
941
|
+
Precomputed filter to apply in the frequency domain.
|
942
|
+
arr_padded : bool, optional
|
943
|
+
Whether the input template is padded and will need to be cropped
|
944
|
+
to arr_shape prior to filter applications. Defaults to False.
|
945
|
+
axes : tuple of ints, optional
|
946
|
+
Axes to perform Fourier transform over.
|
947
|
+
|
948
|
+
Returns
|
949
|
+
-------
|
950
|
+
Callable
|
951
|
+
Filter function with parameters template, ft_temp and template_filter.
|
952
|
+
"""
|
953
|
+
if filter_shape == (1,):
|
954
|
+
return conditional_execute(identity, execute_operation=True)
|
955
|
+
|
956
|
+
# Default case, all shapes are correctly matched
|
957
|
+
def _apply_filter(template, ft_temp, template_filter):
|
958
|
+
ft_temp = be.rfftn(template, out=ft_temp, s=template.shape)
|
959
|
+
ft_temp = be.multiply(ft_temp, template_filter, out=ft_temp)
|
960
|
+
return be.irfftn(ft_temp, out=template, s=template.shape)
|
961
|
+
|
962
|
+
if not arr_padded:
|
963
|
+
return _apply_filter
|
964
|
+
|
965
|
+
# Array is padded but filter is w.r.t to the original template
|
966
|
+
real_subset = tuple(slice(0, x) for x in arr_shape)
|
967
|
+
_template = be.zeros(arr_shape, be._float_dtype)
|
968
|
+
_ft_temp = be.zeros(filter_shape, be._complex_dtype)
|
969
|
+
|
970
|
+
def _apply_filter_subset(template, ft_temp, template_filter):
|
971
|
+
_template[:] = template[real_subset]
|
972
|
+
template[real_subset] = _apply_filter(_template, _ft_temp, template_filter)
|
973
|
+
return template
|
974
|
+
|
975
|
+
return _apply_filter_subset
|
976
|
+
|
977
|
+
|
1174
978
|
MATCHING_EXHAUSTIVE_REGISTER = {
|
1175
979
|
"CC": (cc_setup, corr_scoring),
|
1176
980
|
"LCC": (lcc_setup, corr_scoring),
|
@@ -1179,6 +983,6 @@ MATCHING_EXHAUSTIVE_REGISTER = {
|
|
1179
983
|
"FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
|
1180
984
|
"FLC": (flc_setup, flc_scoring),
|
1181
985
|
"MCC": (mcc_setup, mcc_scoring),
|
1182
|
-
"
|
986
|
+
"batchFLCSphericalMask": (flcSphericalMask_setup, corr_scoring2),
|
1183
987
|
"batchFLC": (flc_setup, flc_scoring2),
|
1184
988
|
}
|