pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl
This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/match_template.py +28 -39
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/postprocess.py +23 -10
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +95 -24
- pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/METADATA +5 -5
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/RECORD +53 -46
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +28 -39
- scripts/postprocess.py +23 -10
- scripts/preprocessor_gui.py +95 -24
- scripts/pytme_runner.py +644 -190
- scripts/refine_matches.py +156 -386
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_utils.py +18 -0
- tests/test_backends.py +3 -9
- tests/test_density.py +0 -1
- tests/test_matching_utils.py +10 -60
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/_utils.py +4 -4
- tme/analyzer/aggregation.py +13 -3
- tme/analyzer/peaks.py +11 -10
- tme/backends/_jax_utils.py +15 -13
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +5 -44
- tme/backends/jax_backend.py +58 -37
- tme/backends/matching_backend.py +6 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +68 -65
- tme/backends/pytorch_backend.py +1 -26
- tme/density.py +2 -6
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/ctf.py +22 -21
- tme/filters/wedge.py +10 -7
- tme/mask.py +341 -0
- tme/matching_data.py +7 -19
- tme/matching_exhaustive.py +34 -47
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +206 -411
- tme/matching_utils.py +73 -422
- tme/memory.py +1 -1
- tme/orientations.py +4 -6
- tme/rotations.py +1 -1
- pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocess.py +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
tme/matching_scores.py
CHANGED
@@ -18,138 +18,9 @@ from .matching_utils import (
     conditional_execute,
     identity,
     normalize_template,
-    _normalize_template_overflow_safe,
 )


-def _shape_match(shape1: Tuple[int], shape2: Tuple[int]) -> bool:
-    """
-    Determine whether ``shape1`` is equal to ``shape2``.
-
-    Parameters
-    ----------
-    shape1, shape2 : tuple of ints
-        Shapes to compare.
-
-    Returns
-    -------
-    Bool
-        ``shape1`` is equal to ``shape2``.
-    """
-    if len(shape1) != len(shape2):
-        return False
-    return shape1 == shape2
-
-
-def _create_filter_func(
-    fwd_shape: Tuple[int],
-    inv_shape: Tuple[int],
-    arr_shape: Tuple[int],
-    arr_filter: BackendArray,
-    arr_ft_shape: Tuple[int],
-    inv_output_shape: Tuple[int],
-    real_dtype: type,
-    cmpl_dtype: type,
-    fwd_axes=None,
-    inv_axes=None,
-    rfftn: Callable = None,
-    irfftn: Callable = None,
-) -> Callable:
-    """
-    Configure template filtering function for Fourier transforms.
-
-    Conceptually we distinguish between three cases. The base case
-    is that both template and the corresponding filter have the same
-    shape. Padding is used when the template filter is larger than
-    the template, for instance to better resolve Fourier filters. Finally
-    this function also handles the case when a filter is supposed to be
-    broadcasted over the template batch dimension.
-
-    Parameters
-    ----------
-    fwd_shape : tuple of ints
-        Input shape of rfftn.
-    inv_shape : tuple of ints
-        Input shape of irfftn.
-    arr_shape : tuple of ints
-        Shape of the array to be filtered.
-    arr_ft_shape : tuple of ints
-        Shape of the Fourier transform of the array.
-    arr_filter : BackendArray
-        Precomputed filter to apply in the frequency domain.
-    rfftn : Callable, optional
-        Foward Fourier transform.
-    irfftn : Callable, optional
-        Inverse Fourier transform.
-
-    Returns
-    -------
-    Callable
-        Filter function with parameters template, ft_temp and template_filter.
-    """
-    if be.size(arr_filter) == 1:
-        return conditional_execute(identity, identity, False)
-
-    filter_shape = tuple(int(x) for x in arr_filter.shape)
-    try:
-        product_ft_shape = np.broadcast_shapes(arr_ft_shape, filter_shape)
-    except ValueError:
-        product_ft_shape, inv_output_shape = filter_shape, arr_shape
-
-    rfft_valid = _shape_match(arr_shape, fwd_shape)
-    rfft_valid = rfft_valid and _shape_match(product_ft_shape, inv_shape)
-    rfft_valid = rfft_valid and rfftn is not None and irfftn is not None
-
-    # FTTs were not or built for the wrong shape
-    if not rfft_valid:
-        _fwd_shape = arr_shape
-        if all(x > y for x, y in zip(arr_shape, product_ft_shape)):
-            _fwd_shape = fwd_shape
-
-        rfftn, irfftn = be.build_fft(
-            fwd_shape=_fwd_shape,
-            inv_shape=product_ft_shape,
-            real_dtype=real_dtype,
-            cmpl_dtype=cmpl_dtype,
-            inv_output_shape=inv_output_shape,
-            fwd_axes=fwd_axes,
-            inv_axes=inv_axes,
-        )
-
-    # Default case, all shapes are correctly matched
-    def _apply_filter(template, ft_temp, template_filter):
-        ft_temp = rfftn(template, ft_temp)
-        ft_temp = be.multiply(ft_temp, template_filter, out=ft_temp)
-        return irfftn(ft_temp, template)
-
-    if not _shape_match(arr_ft_shape, filter_shape):
-        real_subset = tuple(slice(0, x) for x in arr_shape)
-        _template = be.zeros(arr_shape, be._float_dtype)
-        _ft_temp = be.zeros(product_ft_shape, be._complex_dtype)
-
-        # Arr is padded, filter is not
-        def _apply_filter_subset(template, ft_temp, template_filter):
-            # TODO: Benchmark this
-            _template[:] = template[real_subset]
-            template[real_subset] = _apply_filter(_template, _ft_temp, template_filter)
-            return template
-
-        # Filter application requires a broadcasting operation
-        def _apply_filter_broadcast(template, ft_temp, template_filter):
-            _ft_prod = rfftn(template, _ft_temp2)
-            _ft_res = be.multiply(_ft_prod, template_filter, out=_ft_temp)
-            return irfftn(_ft_res, _template)
-
-        if any(x > y and y == 1 for x, y in zip(filter_shape, arr_ft_shape)):
-            _template = be.zeros(inv_output_shape, be._float_dtype)
-            _ft_temp2 = be.zeros((1, *product_ft_shape[1:]), be._complex_dtype)
-            return _apply_filter_broadcast
-
-        return _apply_filter_subset
-
-    return _apply_filter
-
-
 def cc_setup(
     matching_data: type,
     fast_shape: Tuple[int],
@@ -176,8 +47,6 @@ def cc_setup(
     axes = matching_data._batch_axis(matching_data._batch_mask)

     ret = {
-        "fast_shape": fast_shape,
-        "fast_ft_shape": fast_ft_shape,
         "template": be.to_sharedarr(matching_data.template, shm_handler),
         "ft_target": be.to_sharedarr(be.rfftn(target_pad, axes=axes), shm_handler),
         "inv_denominator": be.to_sharedarr(
@@ -219,8 +88,8 @@ def lcc_setup(matching_data, **kwargs) -> Dict:
     for subset in subsets:
         template[subset] = laplace(template[subset], mode="wrap")

-    matching_data._target = target
-    matching_data._template = template
+    matching_data._target = be.to_backend_array(target)
+    matching_data._template = be.to_backend_array(template)

     return cc_setup(matching_data=matching_data, **kwargs)

@@ -316,8 +185,6 @@ def corr_setup(
     denominator = be.multiply(denominator, mask, out=denominator)

     ret = {
-        "fast_shape": fast_shape,
-        "fast_ft_shape": fast_ft_shape,
         "template": be.to_sharedarr(template, shm_handler),
         "ft_target": be.to_sharedarr(ft_target, shm_handler),
         "inv_denominator": be.to_sharedarr(denominator, shm_handler),
@@ -376,8 +243,6 @@ def flc_setup(
     ft_target2 = be.rfftn(target_pad, axes=data_axes)

     ret = {
-        "fast_shape": fast_shape,
-        "fast_ft_shape": fast_ft_shape,
         "template": be.to_sharedarr(matching_data.template, shm_handler),
         "template_mask": be.to_sharedarr(matching_data.template_mask, shm_handler),
         "ft_target": be.to_sharedarr(ft_target, shm_handler),
@@ -438,8 +303,6 @@ def flcSphericalMask_setup(

     temp2 = be.norm_scores(1, temp2, temp, n_obs, be.eps(be._float_dtype), temp2)
     ret = {
-        "fast_shape": fast_shape,
-        "fast_ft_shape": fast_ft_shape,
         "template": be.to_sharedarr(matching_data.template, shm_handler),
         "template_mask": be.to_sharedarr(template_mask, shm_handler),
         "ft_target": be.to_sharedarr(ft_target, shm_handler),
@@ -461,21 +324,14 @@ def mcc_setup(
     Setup function for :py:meth:`mcc_scoring`.
     """
     target, target_mask = matching_data.target, matching_data.target_mask
-    target = be.multiply(target, target_mask
+    target = be.multiply(target, target_mask, out=target)

-    target = be.topleft_pad(
-        target,
-        matching_data._batch_shape(fast_shape, matching_data._template_batch),
-    )
-    target_mask = be.topleft_pad(
-        target_mask,
-        matching_data._batch_shape(fast_shape, matching_data._template_batch),
-    )
     ax = matching_data._batch_axis(matching_data._batch_mask)
+    shape = matching_data._batch_shape(fast_shape, matching_data._template_batch)
+    target = be.topleft_pad(target, shape)
+    target_mask = be.topleft_pad(target_mask, shape)

     ret = {
-        "fast_shape": fast_shape,
-        "fast_ft_shape": fast_ft_shape,
         "template": be.to_sharedarr(matching_data.template, shm_handler),
         "template_mask": be.to_sharedarr(matching_data.template_mask, shm_handler),
         "ft_target": be.to_sharedarr(be.rfftn(target, axes=ax), shm_handler),
@@ -541,8 +397,7 @@ def corr_scoring(

     Returns
     -------
-
-    ``callback`` if provided otherwise None.
+    CallbackClass
     """
     template = be.from_sharedarr(template)
     ft_target = be.from_sharedarr(ft_target)
@@ -550,58 +405,30 @@ def corr_scoring(
     numerator = be.from_sharedarr(numerator)
     template_filter = be.from_sharedarr(template_filter)

-
+    n_obs = None
     if template_mask is not None:
         template_mask = be.from_sharedarr(template_mask)
-
-
-
-
-
-    callback_func = conditional_execute(callback, callback is not None)
-    norm_template = conditional_execute(norm_func, norm_template)
-    norm_numerator = conditional_execute(
-        be.subtract, identity, _shape_match(numerator.shape, fast_shape)
-    )
-    norm_denominator = conditional_execute(
-        be.multiply, identity, _shape_match(inv_denominator.shape, fast_shape)
-    )
+        n_obs = be.sum(template_mask) if template_mask is not None else None
+
+    norm_template = conditional_execute(normalize_template, n_obs is not None)
+    norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
+    norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1))

     arr = be.zeros(fast_shape, be._float_dtype)
     ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
-
-    _fftargs = {
-        "real_dtype": be._float_dtype,
-        "cmpl_dtype": be._complex_dtype,
-        "inv_output_shape": fast_shape,
-        "fwd_axes": None,
-        "inv_axes": None,
-        "inv_shape": fast_ft_shape,
-        "temp_fwd": arr,
-    }
-
-    _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
-    rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
-    _ = _fftargs.pop("temp_fwd", None)
+    template_rot = be.zeros(template.shape, be._float_dtype)

     template_filter_func = _create_filter_func(
         arr_shape=template.shape,
-
-        arr_filter=template_filter,
-        rfftn=rfftn,
-        irfftn=irfftn,
-        **_fftargs,
+        filter_shape=template_filter.shape,
     )

     center = be.divide(be.to_backend_array(template.shape) - 1, 2)
     unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
-
-    template_rot = be.zeros(template.shape, be._float_dtype)
     for index in range(rotations.shape[0]):
-        # d+1, d+1 rigid transform matrix from d,d rotation matrix
         rotation = rotations[index]
         matrix = be._rigid_transform_matrix(rotation_matrix=rotation, center=center)
-
+        _ = be.rigid_transform(
             arr=template,
             rotation_matrix=matrix,
             out=template_rot,
@@ -610,18 +437,17 @@ def corr_scoring(
         )

         template_rot = template_filter_func(template_rot, ft_temp, template_filter)
-        norm_template(template_rot, template_mask,
+        norm_template(template_rot, template_mask, n_obs)

         arr = be.fill(arr, 0)
         arr[unpadded_slice] = template_rot

-        ft_temp = rfftn(arr, ft_temp)
-
-        arr = irfftn(ft_temp, arr)
+        ft_temp = be.rfftn(arr, s=fast_shape, out=ft_temp)
+        arr = _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape)

-        arr =
-        arr =
-
+        arr = norm_sub(arr, numerator, out=arr)
+        arr = norm_mul(arr, inv_denominator, out=arr)
+        callback(arr, rotation_matrix=rotation)

     return callback

@@ -682,6 +508,10 @@ def flc_scoring(
     interpolation_order : int
         Spline order for template rotations.

+    Returns
+    -------
+    CallbackClass
+
     References
     ----------
     .. [1] Hrabe T. et al, J. Struct. Biol. 178, 177 (2012).
@@ -698,63 +528,46 @@ def flc_scoring(
     temp2 = be.zeros(fast_shape, float_dtype)
     ft_temp = be.zeros(fast_ft_shape, complex_dtype)
     ft_denom = be.zeros(fast_ft_shape, complex_dtype)
+    template_rot = be.zeros(template.shape, be._float_dtype)
+    template_mask_rot = be.zeros(template.shape, be._float_dtype)

-
-        "real_dtype": be._float_dtype,
-        "cmpl_dtype": be._complex_dtype,
-        "inv_output_shape": fast_shape,
-        "fwd_axes": None,
-        "inv_axes": None,
-        "inv_shape": fast_ft_shape,
-        "temp_fwd": arr,
-    }
-
-    _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
-    rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
-    _ = _fftargs.pop("temp_fwd", None)
-
-    template_filter_func = _create_filter_func(
-        arr_shape=template.shape,
-        arr_ft_shape=fast_ft_shape,
-        arr_filter=template_filter,
-        rfftn=rfftn,
-        irfftn=irfftn,
-        **_fftargs,
-    )
+    tmpl_filter_func = _create_filter_func(template.shape, template_filter.shape)

     eps = be.eps(float_dtype)
-
+    center = be.divide(be.to_backend_array(template.shape) - 1, 2)
+    unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
     for index in range(rotations.shape[0]):
         rotation = rotations[index]
-
-
-        arr, temp = be.rigid_transform(
+        matrix = be._rigid_transform_matrix(rotation_matrix=rotation, center=center)
+        _ = be.rigid_transform(
             arr=template,
             arr_mask=template_mask,
-            rotation_matrix=
-            out=
-            out_mask=
+            rotation_matrix=matrix,
+            out=template_rot,
+            out_mask=template_mask_rot,
             use_geometric_center=True,
             order=interpolation_order,
             cache=True,
         )

-        n_obs = be.sum(
-
-
+        n_obs = be.sum(template_mask_rot)
+        template_rot = tmpl_filter_func(template_rot, ft_temp, template_filter)
+        template_rot = normalize_template(template_rot, template_mask_rot, n_obs)
+
+        arr = be.fill(arr, 0)
+        temp = be.fill(temp, 0)
+        arr[unpadded_slice] = template_rot
+        temp[unpadded_slice] = template_mask_rot

-        ft_temp = rfftn(temp, ft_temp)
-
-
-        ft_denom = be.multiply(ft_target2, ft_temp, out=ft_denom)
-        temp2 = irfftn(ft_denom, temp2)
+        ft_temp = be.rfftn(temp, out=ft_temp, s=fast_shape)
+        temp = _correlate_fts(ft_target, ft_temp, ft_denom, temp, fast_shape)
+        temp2 = _correlate_fts(ft_target2, ft_temp, ft_denom, temp, fast_shape)

-        ft_temp = rfftn(arr, ft_temp)
-
-        arr = irfftn(ft_temp, arr)
+        ft_temp = be.rfftn(arr, out=ft_temp, s=fast_shape)
+        arr = _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape)

         arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
-
+        callback(arr, rotation_matrix=rotation)

     return callback

@@ -821,6 +634,10 @@ def mcc_scoring(
     overlap_ratio : float, optional
         Required fractional mask overlap, 0.3 by default.

+    Returns
+    -------
+    CallbackClass
+
     References
     ----------
     .. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
@@ -846,30 +663,12 @@ def mcc_scoring(
     temp3 = be.zeros(fast_shape, float_dtype)
     temp_ft = be.zeros(fast_ft_shape, complex_dtype)

-    _fftargs = {
-        "real_dtype": be._float_dtype,
-        "cmpl_dtype": be._complex_dtype,
-        "inv_output_shape": fast_shape,
-        "fwd_axes": None,
-        "inv_axes": None,
-        "inv_shape": fast_ft_shape,
-        "temp_fwd": temp,
-    }
-
-    _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
-    rfftn, irfftn = be.build_fft(temp_inv=temp_ft, **_fftargs)
-    _ = _fftargs.pop("temp_fwd", None)
-
     template_filter_func = _create_filter_func(
         arr_shape=template.shape,
-
-
-        rfftn=rfftn,
-        irfftn=irfftn,
-        **_fftargs,
+        filter_shape=template_filter.shape,
+        arr_padded=True,
     )

-    callback_func = conditional_execute(callback, callback is not None)
     for index in range(rotations.shape[0]):
         rotation = rotations[index]
         template_rot = be.fill(template_rot, 0)
@@ -888,17 +687,19 @@ def mcc_scoring(
         template_filter_func(template_rot, temp_ft, template_filter)
         normalize_template(template_rot, temp, be.sum(temp))

-        temp_ft = rfftn(template_rot, temp_ft)
-        temp2 = irfftn(target_mask_ft * temp_ft, temp2)
-        numerator = irfftn(target_ft * temp_ft, numerator)
+        temp_ft = be.rfftn(template_rot, out=temp_ft, s=fast_shape)
+        temp2 = be.irfftn(target_mask_ft * temp_ft, out=temp2, s=fast_shape)
+        numerator = be.irfftn(target_ft * temp_ft, out=numerator, s=fast_shape)

         # temp template_mask_rot | temp_ft template_mask_rot_ft
         # Calculate overlap of masks at every point in the convolution.
         # Locations with high overlap should not be taken into account.
-        temp_ft = rfftn(temp, temp_ft)
-        mask_overlap = irfftn(
+        temp_ft = be.rfftn(temp, out=temp_ft, s=fast_shape)
+        mask_overlap = be.irfftn(
+            temp_ft * target_mask_ft, out=mask_overlap, s=fast_shape
+        )
         be.maximum(mask_overlap, eps, out=mask_overlap)
-        temp = irfftn(temp_ft * target_ft, temp)
+        temp = be.irfftn(temp_ft * target_ft, out=temp, s=fast_shape)

         be.subtract(
             numerator,
@@ -908,21 +709,21 @@ def mcc_scoring(

         # temp_3 = fixed_denom
         be.multiply(temp_ft, target_ft2, out=temp_ft)
-        temp3 = irfftn(temp_ft, temp3)
+        temp3 = be.irfftn(temp_ft, out=temp3, s=fast_shape)
         be.subtract(temp3, be.divide(be.square(temp), mask_overlap), out=temp3)
         be.maximum(temp3, 0.0, out=temp3)

         # temp = moving_denom
-        temp_ft = rfftn(be.square(template_rot), temp_ft)
+        temp_ft = be.rfftn(be.square(template_rot), out=temp_ft, s=fast_shape)
         be.multiply(target_mask_ft, temp_ft, out=temp_ft)
-        temp = irfftn(temp_ft, temp)
+        temp = be.irfftn(temp_ft, out=temp, s=fast_shape)

         be.subtract(temp, be.divide(be.square(temp2), mask_overlap), out=temp)
         be.maximum(temp, 0.0, out=temp)

         # temp_2 = denom
         be.multiply(temp3, temp, out=temp)
-        be.sqrt(temp, temp2)
+        be.sqrt(temp, out=temp2)

         # Pixels where `denom` is very small will introduce large
         # numbers after division. To get around this problem,
@@ -938,29 +739,11 @@ def mcc_scoring(
             mask_overlap, axis=axes, keepdims=True
         )
         temp[mask_overlap < number_px_threshold] = 0.0
-
+        callback(temp, rotation_matrix=rotation)

     return callback


-def _format_slice(shape, squeeze_axis):
-    ret = tuple(
-        slice(None) if i not in squeeze_axis else 0 for i, _ in enumerate(shape)
-    )
-    return ret
-
-
-def _get_batch_dim(target, template):
-    target_batch, template_batch = [], []
-    for i in range(len(target.shape)):
-        if target.shape[i] == 1 and template.shape[i] != 1:
-            template_batch.append(i)
-        if target.shape[i] != 1 and template.shape[i] == 1:
-            target_batch.append(i)
-
-    return target_batch, template_batch
-
-
 def flc_scoring2(
     template: shm_type,
     template_mask: shm_type,
@@ -973,25 +756,22 @@ def flc_scoring2(
     callback: CallbackClass,
     interpolation_order: int,
 ) -> CallbackClass:
-    callback_func = conditional_execute(callback, callback is not None)
-
-    # Retrieve objects from shared memory
     template = be.from_sharedarr(template)
     template_mask = be.from_sharedarr(template_mask)
     ft_target = be.from_sharedarr(ft_target)
     ft_target2 = be.from_sharedarr(ft_target2)
     template_filter = be.from_sharedarr(template_filter)

-
-    target_batch, template_batch = _get_batch_dim(ft_target, template)
-    sqz_cmpl = tuple(1 if i in target_batch else x for i, x in enumerate(fast_ft_shape))
-    sqz_slice = tuple(slice(0, 1) if x == 1 else slice(None) for x in sqz_cmpl)
+    tar_batch, tmpl_batch = _get_batch_dim(ft_target, template)

-
-    if
-
-
-
+    nd = len(fast_shape)
+    sqz_slice = tuple(slice(0, 1) if i in tar_batch else slice(None) for i in range(nd))
+    tmpl_subset = tuple(0 if i in tar_batch else slice(None) for i in range(nd))
+
+    axes, shape, batched = None, fast_shape, len(tmpl_batch) > 0
+    if len(tar_batch) or len(tmpl_batch):
+        axes = tuple(i for i in range(nd) if i not in (*tar_batch, *tmpl_batch))
+        shape = tuple(fast_shape[i] for i in axes)

     arr = be.zeros(fast_shape, be._float_dtype)
     temp = be.zeros(fast_shape, be._float_dtype)
@@ -999,34 +779,11 @@ def flc_scoring2(
     ft_denom = be.zeros(fast_ft_shape, be._complex_dtype)

     tmp_sqz, arr_sqz, ft_temp = temp[sqz_slice], arr[sqz_slice], ft_denom[sqz_slice]
-    if be.size(template_filter) != 1:
-        ret_shape = np.broadcast_shapes(
-            sqz_cmpl, tuple(int(x) for x in template_filter.shape)
-        )
-        ft_temp = be.zeros(ret_shape, be._complex_dtype)
-
-    _fftargs = {
-        "real_dtype": be._float_dtype,
-        "cmpl_dtype": be._complex_dtype,
-        "inv_output_shape": fast_shape,
-        "fwd_axes": data_axes,
-        "inv_axes": data_axes,
-        "inv_shape": fast_ft_shape,
-        "temp_fwd": arr_sqz if _shape_match(ft_temp.shape, sqz_cmpl) else arr,
-    }
-
-    # build_fft ignores fwd_shape if temp_fwd is given and serves only for bookkeeping
-    _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
-    rfftn, irfftn = be.build_fft(temp_inv=ft_denom, **_fftargs)
-    _ = _fftargs.pop("temp_fwd", None)

     template_filter_func = _create_filter_func(
         arr_shape=template.shape,
-
-
-        rfftn=rfftn,
-        irfftn=irfftn,
-        **_fftargs,
+        filter_shape=template_filter.shape,
+        arr_padded=True,
     )

     eps = be.eps(be._float_dtype)
@@ -1034,33 +791,32 @@ def flc_scoring2(
         rotation = rotations[index]
         be.fill(arr, 0)
         be.fill(temp, 0)
-
-
-
+
+        _, _ = be.rigid_transform(
+            arr=template[tmpl_subset],
+            arr_mask=template_mask[tmpl_subset],
             rotation_matrix=rotation,
-            out=arr_sqz,
-            out_mask=tmp_sqz,
+            out=arr_sqz[tmpl_subset],
+            out_mask=tmp_sqz[tmpl_subset],
             use_geometric_center=True,
             order=interpolation_order,
-            cache=
-            batched=
+            cache=False,
+            batched=batched,
         )
-
+
+        n_obs = be.sum(tmp_sqz, axis=axes, keepdims=True)
         arr_norm = template_filter_func(arr_sqz, ft_temp, template_filter)
-        arr_norm = normalize_template(arr_norm, tmp_sqz, n_obs, axis=
+        arr_norm = normalize_template(arr_norm, tmp_sqz, n_obs, axis=axes)

-        ft_temp = be.rfftn(tmp_sqz, ft_temp, axes=
-
-
-        ft_denom = be.multiply(ft_target2, ft_temp, out=ft_denom)
-        temp2 = be.irfftn(ft_denom, temp2, axes=data_axes, s=data_shape)
+        ft_temp = be.rfftn(tmp_sqz, out=ft_temp, axes=axes, s=shape)
+        temp = _correlate_fts(ft_target, ft_temp, ft_denom, temp, shape, axes)
+        temp2 = _correlate_fts(ft_target2, ft_temp, ft_denom, temp2, shape, axes)

-        ft_temp = rfftn(arr_norm,
-
-        arr = irfftn(ft_denom, arr)
+        ft_temp = be.rfftn(arr_norm, out=ft_temp, axes=axes, s=shape)
+        arr = _correlate_fts(ft_target, ft_temp, ft_denom, arr, shape, axes)

-        be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
-
+        arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
+        callback(arr, rotation_matrix=rotation)

     return callback

@@ -1085,101 +841,140 @@ def corr_scoring2(
     numerator = be.from_sharedarr(numerator)
     template_filter = be.from_sharedarr(template_filter)

-
-
-
-    sqz_slice = tuple(slice(0, 1) if
-
-
-
-
-
-
-
-
+    tar_batch, tmpl_batch = _get_batch_dim(ft_target, template)
+
+    nd = len(fast_shape)
+    sqz_slice = tuple(slice(0, 1) if i in tar_batch else slice(None) for i in range(nd))
+    tmpl_subset = tuple(0 if i in tar_batch else slice(None) for i in range(nd))
+
+    axes, shape, batched = None, fast_shape, len(tmpl_batch) > 0
+    if len(tar_batch) or len(tmpl_batch):
+        axes = tuple(i for i in range(nd) if i not in (*tar_batch, *tmpl_batch))
+        shape = tuple(fast_shape[i] for i in axes)
+
+    unpadded_slice = tuple(
+        slice(None) if i in (*tar_batch, *tmpl_batch) else slice(0, x)
+        for i, x in enumerate(template.shape)
+    )

     arr = be.zeros(fast_shape, be._float_dtype)
     ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
     arr_sqz, ft_sqz = arr[sqz_slice], ft_temp[sqz_slice]

-
-    # The filter could be w.r.t the unpadded template
-    ret_shape = tuple(
-        int(x * y) if x == 1 or y == 1 else y
-        for x, y in zip(sqz_cmpl, template_filter.shape)
-    )
-    ft_sqz = be.zeros(ret_shape, be._complex_dtype)
-
-    norm_func, norm_template, mask_sum = normalize_template, False, 1
+    n_obs = None
     if template_mask is not None:
         template_mask = be.from_sharedarr(template_mask)
-
-            be.astype(template_mask, be._overflow_safe_dtype),
-            axis=data_axes,
-            keepdims=True,
-        )
-        if be.datatype_bytes(template_mask.dtype) == 2:
-            norm_func = _normalize_template_overflow_safe
-
-    callback_func = conditional_execute(callback, callback is not None)
-    norm_template = conditional_execute(norm_func, norm_template)
-    norm_numerator = conditional_execute(
-        be.subtract, identity, _shape_match(numerator.shape, fast_shape)
-    )
-    norm_denominator = conditional_execute(
-        be.multiply, identity, _shape_match(inv_denominator.shape, fast_shape)
-    )
+        n_obs = be.sum(template_mask, axis=axes, keepdims=True)

-
-
-
-        "fwd_axes": data_axes,
-        "inv_axes": data_axes,
-        "inv_shape": fast_ft_shape,
-        "inv_output_shape": fast_shape,
-        "temp_fwd": arr_sqz if _shape_match(ft_sqz.shape, sqz_cmpl) else arr,
-    }
-
-    # build_fft ignores fwd_shape if temp_fwd is given and serves only for bookkeeping
-    _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
-    rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
-    _ = _fftargs.pop("temp_fwd", None)
+    norm_template = conditional_execute(normalize_template, n_obs is not None)
+    norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
+    norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1,))

     template_filter_func = _create_filter_func(
         arr_shape=template.shape,
-
-
-        rfftn=rfftn,
-        irfftn=irfftn,
-        **_fftargs,
+        filter_shape=template_filter.shape,
+        arr_padded=True,
     )

     for index in range(rotations.shape[0]):
         be.fill(arr, 0)
         rotation = rotations[index]
-
-            arr=template,
+        _, _ = be.rigid_transform(
+            arr=template[tmpl_subset],
             rotation_matrix=rotation,
-            out=arr_sqz,
+            out=arr_sqz[tmpl_subset],
             use_geometric_center=True,
             order=interpolation_order,
-            cache=
-            batched=
+            cache=False,
+            batched=batched,
         )
         arr_norm = template_filter_func(arr_sqz, ft_sqz, template_filter)
-        norm_template(arr_norm[unpadded_slice], template_mask,
+        norm_template(arr_norm[unpadded_slice], template_mask, n_obs, axis=axes)

-        ft_sqz = rfftn(arr_norm, ft_sqz)
-
-        arr = irfftn(ft_temp, arr)
+        ft_sqz = be.rfftn(arr_norm, out=ft_sqz, axes=axes, s=shape)
+        arr = _correlate_fts(ft_target, ft_sqz, ft_temp, arr, shape, axes)

-        arr =
-        arr =
-
+        arr = norm_sub(arr, numerator, out=arr)
+        arr = norm_mul(arr, inv_denominator, out=arr)
+        callback(arr, rotation_matrix=rotation)

     return callback


+def _get_batch_dim(target, template):
+    target_batch, template_batch = [], []
+    for i in range(len(target.shape)):
+        if target.shape[i] == 1 and template.shape[i] != 1:
+            template_batch.append(i)
+        if target.shape[i] != 1 and template.shape[i] == 1:
+            target_batch.append(i)
+
+    return target_batch, template_batch
+
+
+def _correlate_fts(ft_tar, ft_tmpl, ft_buffer, real_buffer, fast_shape, axes=None):
+    ft_buffer = be.multiply(ft_tar, ft_tmpl, out=ft_buffer)
+    return be.irfftn(ft_buffer, out=real_buffer, s=fast_shape, axes=axes)
+
+
+def _create_filter_func(
+    arr_shape: Tuple[int],
+    filter_shape: BackendArray,
+    arr_padded: bool = False,
+    axes=None,
+) -> Callable:
+    """
+    Configure template filtering function for Fourier transforms.
+
+    Conceptually we distinguish between three cases. The base case
+    is that both template and the corresponding filter have the same
+    shape. Padding is used when the template filter is larger than
+    the template, for instance to better resolve Fourier filters. Finally
+    this function also handles the case when a filter is supposed to be
+    broadcasted over the template batch dimension.
+
+    Parameters
+    ----------
+    arr_shape : tuple of ints
+        Shape of the array to be filtered.
+    filter_shape : BackendArray
+        Precomputed filter to apply in the frequency domain.
+    arr_padded : bool, optional
+        Whether the input template is padded and will need to be cropped
+        to arr_shape prior to filter applications. Defaults to False.
+    axes : tuple of ints, optional
+        Axes to perform Fourier transform over.
+
+    Returns
+    -------
+    Callable
+        Filter function with parameters template, ft_temp and template_filter.
+    """
+    if filter_shape == (1,):
+        return conditional_execute(identity, execute_operation=True)
+
+    # Default case, all shapes are correctly matched
+    def _apply_filter(template, ft_temp, template_filter):
+        ft_temp = be.rfftn(template, out=ft_temp, s=template.shape)
+        ft_temp = be.multiply(ft_temp, template_filter, out=ft_temp)
+        return be.irfftn(ft_temp, out=template, s=template.shape)
+
+    if not arr_padded:
+        return _apply_filter
+
+    # Array is padded but filter is w.r.t to the original template
+    real_subset = tuple(slice(0, x) for x in arr_shape)
+    _template = be.zeros(arr_shape, be._float_dtype)
+    _ft_temp = be.zeros(filter_shape, be._complex_dtype)
+
+    def _apply_filter_subset(template, ft_temp, template_filter):
+        _template[:] = template[real_subset]
+        template[real_subset] = _apply_filter(_template, _ft_temp, template_filter)
+        return template
+
+    return _apply_filter_subset
+
+
 MATCHING_EXHAUSTIVE_REGISTER = {
     "CC": (cc_setup, corr_scoring),
     "LCC": (lcc_setup, corr_scoring),
@@ -1188,6 +983,6 @@ MATCHING_EXHAUSTIVE_REGISTER = {
     "FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
     "FLC": (flc_setup, flc_scoring),
     "MCC": (mcc_setup, mcc_scoring),
-    "
+    "batchFLCSphericalMask": (flcSphericalMask_setup, corr_scoring2),
     "batchFLC": (flc_setup, flc_scoring2),
 }
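
Note on the refactor above: the per-score `build_fft` plumbing is removed in favor of direct `be.rfftn`/`be.irfftn` calls plus the small `_correlate_fts` helper, which multiplies two real-FFT spectra into a buffer and inverse-transforms the product. The following is a minimal NumPy sketch of that frequency-domain pattern, illustrative only and outside pytme's backend abstraction; the function and variable names are not from the package:

    import numpy as np

    def correlate_fts(ft_target, ft_template, shape):
        # Pointwise product of two rFFT spectra followed by an inverse rFFT;
        # this is the frequency-domain core shared by the scoring kernels.
        return np.fft.irfftn(ft_target * ft_template, s=shape)

    rng = np.random.default_rng(0)
    target = rng.random((32, 32, 32)).astype(np.float32)
    padded_template = np.zeros_like(target)
    padded_template[:8, :8, :8] = rng.random((8, 8, 8))  # template zero-padded to the transform shape

    scores = correlate_fts(np.fft.rfftn(target), np.fft.rfftn(padded_template), target.shape)
    print(scores.shape)  # (32, 32, 32)

Unlike this sketch, the in-package helper writes into preallocated buffers via `out=` so no arrays are reallocated per rotation.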