pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.dev20250731__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.1.dev20250731.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/match_template.py +30 -41
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/postprocess.py +35 -21
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocessor_gui.py +96 -24
- pytme-0.3.1.dev20250731.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/METADATA +5 -7
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/RECORD +59 -49
- scripts/estimate_ram_usage.py +97 -0
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +30 -41
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +35 -21
- scripts/preprocessor_gui.py +96 -24
- scripts/pytme_runner.py +644 -190
- scripts/refine_matches.py +158 -390
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_utils.py +18 -0
- tests/test_analyzer.py +2 -3
- tests/test_backends.py +3 -9
- tests/test_density.py +0 -1
- tests/test_extensions.py +0 -1
- tests/test_matching_utils.py +10 -60
- tests/test_orientations.py +0 -12
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/_utils.py +4 -4
- tme/analyzer/aggregation.py +35 -15
- tme/analyzer/peaks.py +11 -10
- tme/backends/_jax_utils.py +64 -18
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +16 -55
- tme/backends/jax_backend.py +79 -40
- tme/backends/matching_backend.py +17 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +71 -65
- tme/backends/pytorch_backend.py +1 -26
- tme/density.py +58 -5
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/ctf.py +22 -21
- tme/filters/wedge.py +10 -7
- tme/mask.py +341 -0
- tme/matching_data.py +31 -19
- tme/matching_exhaustive.py +37 -47
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +229 -411
- tme/matching_utils.py +73 -422
- tme/memory.py +1 -1
- tme/orientations.py +24 -13
- tme/rotations.py +1 -1
- pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocess.py +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/entry_points.txt +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/top_level.txt +0 -0
tme/matching_scores.py
CHANGED
@@ -18,138 +18,9 @@ from .matching_utils import (
|
|
18
18
|
conditional_execute,
|
19
19
|
identity,
|
20
20
|
normalize_template,
|
21
|
-
_normalize_template_overflow_safe,
|
22
21
|
)
|
23
22
|
|
24
23
|
|
25
|
-
def _shape_match(shape1: Tuple[int], shape2: Tuple[int]) -> bool:
|
26
|
-
"""
|
27
|
-
Determine whether ``shape1`` is equal to ``shape2``.
|
28
|
-
|
29
|
-
Parameters
|
30
|
-
----------
|
31
|
-
shape1, shape2 : tuple of ints
|
32
|
-
Shapes to compare.
|
33
|
-
|
34
|
-
Returns
|
35
|
-
-------
|
36
|
-
Bool
|
37
|
-
``shape1`` is equal to ``shape2``.
|
38
|
-
"""
|
39
|
-
if len(shape1) != len(shape2):
|
40
|
-
return False
|
41
|
-
return shape1 == shape2
|
42
|
-
|
43
|
-
|
44
|
-
def _create_filter_func(
|
45
|
-
fwd_shape: Tuple[int],
|
46
|
-
inv_shape: Tuple[int],
|
47
|
-
arr_shape: Tuple[int],
|
48
|
-
arr_filter: BackendArray,
|
49
|
-
arr_ft_shape: Tuple[int],
|
50
|
-
inv_output_shape: Tuple[int],
|
51
|
-
real_dtype: type,
|
52
|
-
cmpl_dtype: type,
|
53
|
-
fwd_axes=None,
|
54
|
-
inv_axes=None,
|
55
|
-
rfftn: Callable = None,
|
56
|
-
irfftn: Callable = None,
|
57
|
-
) -> Callable:
|
58
|
-
"""
|
59
|
-
Configure template filtering function for Fourier transforms.
|
60
|
-
|
61
|
-
Conceptually we distinguish between three cases. The base case
|
62
|
-
is that both template and the corresponding filter have the same
|
63
|
-
shape. Padding is used when the template filter is larger than
|
64
|
-
the template, for instance to better resolve Fourier filters. Finally
|
65
|
-
this function also handles the case when a filter is supposed to be
|
66
|
-
broadcasted over the template batch dimension.
|
67
|
-
|
68
|
-
Parameters
|
69
|
-
----------
|
70
|
-
fwd_shape : tuple of ints
|
71
|
-
Input shape of rfftn.
|
72
|
-
inv_shape : tuple of ints
|
73
|
-
Input shape of irfftn.
|
74
|
-
arr_shape : tuple of ints
|
75
|
-
Shape of the array to be filtered.
|
76
|
-
arr_ft_shape : tuple of ints
|
77
|
-
Shape of the Fourier transform of the array.
|
78
|
-
arr_filter : BackendArray
|
79
|
-
Precomputed filter to apply in the frequency domain.
|
80
|
-
rfftn : Callable, optional
|
81
|
-
Foward Fourier transform.
|
82
|
-
irfftn : Callable, optional
|
83
|
-
Inverse Fourier transform.
|
84
|
-
|
85
|
-
Returns
|
86
|
-
-------
|
87
|
-
Callable
|
88
|
-
Filter function with parameters template, ft_temp and template_filter.
|
89
|
-
"""
|
90
|
-
if be.size(arr_filter) == 1:
|
91
|
-
return conditional_execute(identity, identity, False)
|
92
|
-
|
93
|
-
filter_shape = tuple(int(x) for x in arr_filter.shape)
|
94
|
-
try:
|
95
|
-
product_ft_shape = np.broadcast_shapes(arr_ft_shape, filter_shape)
|
96
|
-
except ValueError:
|
97
|
-
product_ft_shape, inv_output_shape = filter_shape, arr_shape
|
98
|
-
|
99
|
-
rfft_valid = _shape_match(arr_shape, fwd_shape)
|
100
|
-
rfft_valid = rfft_valid and _shape_match(product_ft_shape, inv_shape)
|
101
|
-
rfft_valid = rfft_valid and rfftn is not None and irfftn is not None
|
102
|
-
|
103
|
-
# FTTs were not or built for the wrong shape
|
104
|
-
if not rfft_valid:
|
105
|
-
_fwd_shape = arr_shape
|
106
|
-
if all(x > y for x, y in zip(arr_shape, product_ft_shape)):
|
107
|
-
_fwd_shape = fwd_shape
|
108
|
-
|
109
|
-
rfftn, irfftn = be.build_fft(
|
110
|
-
fwd_shape=_fwd_shape,
|
111
|
-
inv_shape=product_ft_shape,
|
112
|
-
real_dtype=real_dtype,
|
113
|
-
cmpl_dtype=cmpl_dtype,
|
114
|
-
inv_output_shape=inv_output_shape,
|
115
|
-
fwd_axes=fwd_axes,
|
116
|
-
inv_axes=inv_axes,
|
117
|
-
)
|
118
|
-
|
119
|
-
# Default case, all shapes are correctly matched
|
120
|
-
def _apply_filter(template, ft_temp, template_filter):
|
121
|
-
ft_temp = rfftn(template, ft_temp)
|
122
|
-
ft_temp = be.multiply(ft_temp, template_filter, out=ft_temp)
|
123
|
-
return irfftn(ft_temp, template)
|
124
|
-
|
125
|
-
if not _shape_match(arr_ft_shape, filter_shape):
|
126
|
-
real_subset = tuple(slice(0, x) for x in arr_shape)
|
127
|
-
_template = be.zeros(arr_shape, be._float_dtype)
|
128
|
-
_ft_temp = be.zeros(product_ft_shape, be._complex_dtype)
|
129
|
-
|
130
|
-
# Arr is padded, filter is not
|
131
|
-
def _apply_filter_subset(template, ft_temp, template_filter):
|
132
|
-
# TODO: Benchmark this
|
133
|
-
_template[:] = template[real_subset]
|
134
|
-
template[real_subset] = _apply_filter(_template, _ft_temp, template_filter)
|
135
|
-
return template
|
136
|
-
|
137
|
-
# Filter application requires a broadcasting operation
|
138
|
-
def _apply_filter_broadcast(template, ft_temp, template_filter):
|
139
|
-
_ft_prod = rfftn(template, _ft_temp2)
|
140
|
-
_ft_res = be.multiply(_ft_prod, template_filter, out=_ft_temp)
|
141
|
-
return irfftn(_ft_res, _template)
|
142
|
-
|
143
|
-
if any(x > y and y == 1 for x, y in zip(filter_shape, arr_ft_shape)):
|
144
|
-
_template = be.zeros(inv_output_shape, be._float_dtype)
|
145
|
-
_ft_temp2 = be.zeros((1, *product_ft_shape[1:]), be._complex_dtype)
|
146
|
-
return _apply_filter_broadcast
|
147
|
-
|
148
|
-
return _apply_filter_subset
|
149
|
-
|
150
|
-
return _apply_filter
|
151
|
-
|
152
|
-
|
153
24
|
def cc_setup(
|
154
25
|
matching_data: type,
|
155
26
|
fast_shape: Tuple[int],
|
@@ -176,8 +47,6 @@ def cc_setup(
|
|
176
47
|
axes = matching_data._batch_axis(matching_data._batch_mask)
|
177
48
|
|
178
49
|
ret = {
|
179
|
-
"fast_shape": fast_shape,
|
180
|
-
"fast_ft_shape": fast_ft_shape,
|
181
50
|
"template": be.to_sharedarr(matching_data.template, shm_handler),
|
182
51
|
"ft_target": be.to_sharedarr(be.rfftn(target_pad, axes=axes), shm_handler),
|
183
52
|
"inv_denominator": be.to_sharedarr(
|
@@ -219,8 +88,8 @@ def lcc_setup(matching_data, **kwargs) -> Dict:
|
|
219
88
|
for subset in subsets:
|
220
89
|
template[subset] = laplace(template[subset], mode="wrap")
|
221
90
|
|
222
|
-
matching_data._target = target
|
223
|
-
matching_data._template = template
|
91
|
+
matching_data._target = be.to_backend_array(target)
|
92
|
+
matching_data._template = be.to_backend_array(template)
|
224
93
|
|
225
94
|
return cc_setup(matching_data=matching_data, **kwargs)
|
226
95
|
|
@@ -316,8 +185,6 @@ def corr_setup(
|
|
316
185
|
denominator = be.multiply(denominator, mask, out=denominator)
|
317
186
|
|
318
187
|
ret = {
|
319
|
-
"fast_shape": fast_shape,
|
320
|
-
"fast_ft_shape": fast_ft_shape,
|
321
188
|
"template": be.to_sharedarr(template, shm_handler),
|
322
189
|
"ft_target": be.to_sharedarr(ft_target, shm_handler),
|
323
190
|
"inv_denominator": be.to_sharedarr(denominator, shm_handler),
|
@@ -376,8 +243,6 @@ def flc_setup(
|
|
376
243
|
ft_target2 = be.rfftn(target_pad, axes=data_axes)
|
377
244
|
|
378
245
|
ret = {
|
379
|
-
"fast_shape": fast_shape,
|
380
|
-
"fast_ft_shape": fast_ft_shape,
|
381
246
|
"template": be.to_sharedarr(matching_data.template, shm_handler),
|
382
247
|
"template_mask": be.to_sharedarr(matching_data.template_mask, shm_handler),
|
383
248
|
"ft_target": be.to_sharedarr(ft_target, shm_handler),
|
@@ -438,8 +303,6 @@ def flcSphericalMask_setup(
|
|
438
303
|
|
439
304
|
temp2 = be.norm_scores(1, temp2, temp, n_obs, be.eps(be._float_dtype), temp2)
|
440
305
|
ret = {
|
441
|
-
"fast_shape": fast_shape,
|
442
|
-
"fast_ft_shape": fast_ft_shape,
|
443
306
|
"template": be.to_sharedarr(matching_data.template, shm_handler),
|
444
307
|
"template_mask": be.to_sharedarr(template_mask, shm_handler),
|
445
308
|
"ft_target": be.to_sharedarr(ft_target, shm_handler),
|
@@ -461,21 +324,14 @@ def mcc_setup(
|
|
461
324
|
Setup function for :py:meth:`mcc_scoring`.
|
462
325
|
"""
|
463
326
|
target, target_mask = matching_data.target, matching_data.target_mask
|
464
|
-
target = be.multiply(target, target_mask
|
327
|
+
target = be.multiply(target, target_mask, out=target)
|
465
328
|
|
466
|
-
target = be.topleft_pad(
|
467
|
-
target,
|
468
|
-
matching_data._batch_shape(fast_shape, matching_data._template_batch),
|
469
|
-
)
|
470
|
-
target_mask = be.topleft_pad(
|
471
|
-
target_mask,
|
472
|
-
matching_data._batch_shape(fast_shape, matching_data._template_batch),
|
473
|
-
)
|
474
329
|
ax = matching_data._batch_axis(matching_data._batch_mask)
|
330
|
+
shape = matching_data._batch_shape(fast_shape, matching_data._template_batch)
|
331
|
+
target = be.topleft_pad(target, shape)
|
332
|
+
target_mask = be.topleft_pad(target_mask, shape)
|
475
333
|
|
476
334
|
ret = {
|
477
|
-
"fast_shape": fast_shape,
|
478
|
-
"fast_ft_shape": fast_ft_shape,
|
479
335
|
"template": be.to_sharedarr(matching_data.template, shm_handler),
|
480
336
|
"template_mask": be.to_sharedarr(matching_data.template_mask, shm_handler),
|
481
337
|
"ft_target": be.to_sharedarr(be.rfftn(target, axes=ax), shm_handler),
|
@@ -500,6 +356,7 @@ def corr_scoring(
|
|
500
356
|
callback: CallbackClass,
|
501
357
|
interpolation_order: int,
|
502
358
|
template_mask: shm_type = None,
|
359
|
+
score_mask: shm_type = None,
|
503
360
|
) -> CallbackClass:
|
504
361
|
"""
|
505
362
|
Calculates a normalized cross-correlation between a target f and a template g.
|
@@ -538,70 +395,45 @@ def corr_scoring(
|
|
538
395
|
Spline order for template rotations.
|
539
396
|
template_mask : Union[Tuple[type, tuple of ints, type], BackendArray], optional
|
540
397
|
Template mask data buffer, its shape and datatype, None by default.
|
398
|
+
score_mask : Union[Tuple[type, tuple of ints, type], BackendArray], optional
|
399
|
+
Score mask data buffer, its shape and datatype, None by default.
|
541
400
|
|
542
401
|
Returns
|
543
402
|
-------
|
544
|
-
|
545
|
-
``callback`` if provided otherwise None.
|
403
|
+
CallbackClass
|
546
404
|
"""
|
547
405
|
template = be.from_sharedarr(template)
|
548
406
|
ft_target = be.from_sharedarr(ft_target)
|
549
407
|
inv_denominator = be.from_sharedarr(inv_denominator)
|
550
408
|
numerator = be.from_sharedarr(numerator)
|
551
409
|
template_filter = be.from_sharedarr(template_filter)
|
410
|
+
score_mask = be.from_sharedarr(score_mask)
|
552
411
|
|
553
|
-
|
412
|
+
n_obs = None
|
554
413
|
if template_mask is not None:
|
555
414
|
template_mask = be.from_sharedarr(template_mask)
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
norm_template = conditional_execute(norm_func, norm_template)
|
563
|
-
norm_numerator = conditional_execute(
|
564
|
-
be.subtract, identity, _shape_match(numerator.shape, fast_shape)
|
565
|
-
)
|
566
|
-
norm_denominator = conditional_execute(
|
567
|
-
be.multiply, identity, _shape_match(inv_denominator.shape, fast_shape)
|
568
|
-
)
|
415
|
+
n_obs = be.sum(template_mask) if template_mask is not None else None
|
416
|
+
|
417
|
+
norm_template = conditional_execute(normalize_template, n_obs is not None)
|
418
|
+
norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
|
419
|
+
norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1))
|
420
|
+
norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
|
569
421
|
|
570
422
|
arr = be.zeros(fast_shape, be._float_dtype)
|
571
423
|
ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
|
572
|
-
|
573
|
-
_fftargs = {
|
574
|
-
"real_dtype": be._float_dtype,
|
575
|
-
"cmpl_dtype": be._complex_dtype,
|
576
|
-
"inv_output_shape": fast_shape,
|
577
|
-
"fwd_axes": None,
|
578
|
-
"inv_axes": None,
|
579
|
-
"inv_shape": fast_ft_shape,
|
580
|
-
"temp_fwd": arr,
|
581
|
-
}
|
582
|
-
|
583
|
-
_fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
|
584
|
-
rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
|
585
|
-
_ = _fftargs.pop("temp_fwd", None)
|
424
|
+
template_rot = be.zeros(template.shape, be._float_dtype)
|
586
425
|
|
587
426
|
template_filter_func = _create_filter_func(
|
588
427
|
arr_shape=template.shape,
|
589
|
-
|
590
|
-
arr_filter=template_filter,
|
591
|
-
rfftn=rfftn,
|
592
|
-
irfftn=irfftn,
|
593
|
-
**_fftargs,
|
428
|
+
filter_shape=template_filter.shape,
|
594
429
|
)
|
595
430
|
|
596
431
|
center = be.divide(be.to_backend_array(template.shape) - 1, 2)
|
597
432
|
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
598
|
-
|
599
|
-
template_rot = be.zeros(template.shape, be._float_dtype)
|
600
433
|
for index in range(rotations.shape[0]):
|
601
|
-
# d+1, d+1 rigid transform matrix from d,d rotation matrix
|
602
434
|
rotation = rotations[index]
|
603
435
|
matrix = be._rigid_transform_matrix(rotation_matrix=rotation, center=center)
|
604
|
-
|
436
|
+
_ = be.rigid_transform(
|
605
437
|
arr=template,
|
606
438
|
rotation_matrix=matrix,
|
607
439
|
out=template_rot,
|
@@ -610,18 +442,19 @@ def corr_scoring(
|
|
610
442
|
)
|
611
443
|
|
612
444
|
template_rot = template_filter_func(template_rot, ft_temp, template_filter)
|
613
|
-
norm_template(template_rot, template_mask,
|
445
|
+
norm_template(template_rot, template_mask, n_obs)
|
614
446
|
|
615
447
|
arr = be.fill(arr, 0)
|
616
448
|
arr[unpadded_slice] = template_rot
|
617
449
|
|
618
|
-
ft_temp = rfftn(arr, ft_temp)
|
619
|
-
|
620
|
-
|
450
|
+
ft_temp = be.rfftn(arr, s=fast_shape, out=ft_temp)
|
451
|
+
arr = _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape)
|
452
|
+
|
453
|
+
arr = norm_sub(arr, numerator, out=arr)
|
454
|
+
arr = norm_mul(arr, inv_denominator, out=arr)
|
455
|
+
arr = norm_mask(arr, score_mask, out=arr)
|
621
456
|
|
622
|
-
|
623
|
-
arr = norm_denominator(arr, inv_denominator, out=arr)
|
624
|
-
callback_func(arr, rotation_matrix=rotation)
|
457
|
+
callback(arr, rotation_matrix=rotation)
|
625
458
|
|
626
459
|
return callback
|
627
460
|
|
@@ -637,6 +470,7 @@ def flc_scoring(
|
|
637
470
|
rotations: BackendArray,
|
638
471
|
callback: CallbackClass,
|
639
472
|
interpolation_order: int,
|
473
|
+
score_mask: shm_type = None,
|
640
474
|
) -> CallbackClass:
|
641
475
|
"""
|
642
476
|
Computes a normalized cross-correlation between ``target`` (f),
|
@@ -682,6 +516,10 @@ def flc_scoring(
|
|
682
516
|
interpolation_order : int
|
683
517
|
Spline order for template rotations.
|
684
518
|
|
519
|
+
Returns
|
520
|
+
-------
|
521
|
+
CallbackClass
|
522
|
+
|
685
523
|
References
|
686
524
|
----------
|
687
525
|
.. [1] Hrabe T. et al, J. Struct. Biol. 178, 177 (2012).
|
@@ -692,69 +530,56 @@ def flc_scoring(
|
|
692
530
|
ft_target = be.from_sharedarr(ft_target)
|
693
531
|
ft_target2 = be.from_sharedarr(ft_target2)
|
694
532
|
template_filter = be.from_sharedarr(template_filter)
|
533
|
+
score_mask = be.from_sharedarr(score_mask)
|
695
534
|
|
696
535
|
arr = be.zeros(fast_shape, float_dtype)
|
697
536
|
temp = be.zeros(fast_shape, float_dtype)
|
698
537
|
temp2 = be.zeros(fast_shape, float_dtype)
|
699
538
|
ft_temp = be.zeros(fast_ft_shape, complex_dtype)
|
700
539
|
ft_denom = be.zeros(fast_ft_shape, complex_dtype)
|
540
|
+
template_rot = be.zeros(template.shape, be._float_dtype)
|
541
|
+
template_mask_rot = be.zeros(template.shape, be._float_dtype)
|
701
542
|
|
702
|
-
|
703
|
-
|
704
|
-
"cmpl_dtype": be._complex_dtype,
|
705
|
-
"inv_output_shape": fast_shape,
|
706
|
-
"fwd_axes": None,
|
707
|
-
"inv_axes": None,
|
708
|
-
"inv_shape": fast_ft_shape,
|
709
|
-
"temp_fwd": arr,
|
710
|
-
}
|
711
|
-
|
712
|
-
_fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
|
713
|
-
rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
|
714
|
-
_ = _fftargs.pop("temp_fwd", None)
|
715
|
-
|
716
|
-
template_filter_func = _create_filter_func(
|
717
|
-
arr_shape=template.shape,
|
718
|
-
arr_ft_shape=fast_ft_shape,
|
719
|
-
arr_filter=template_filter,
|
720
|
-
rfftn=rfftn,
|
721
|
-
irfftn=irfftn,
|
722
|
-
**_fftargs,
|
723
|
-
)
|
543
|
+
tmpl_filter_func = _create_filter_func(template.shape, template_filter.shape)
|
544
|
+
norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
|
724
545
|
|
725
546
|
eps = be.eps(float_dtype)
|
726
|
-
|
547
|
+
center = be.divide(be.to_backend_array(template.shape) - 1, 2)
|
548
|
+
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
727
549
|
for index in range(rotations.shape[0]):
|
728
550
|
rotation = rotations[index]
|
729
|
-
|
730
|
-
|
731
|
-
arr, temp = be.rigid_transform(
|
551
|
+
matrix = be._rigid_transform_matrix(rotation_matrix=rotation, center=center)
|
552
|
+
_ = be.rigid_transform(
|
732
553
|
arr=template,
|
733
554
|
arr_mask=template_mask,
|
734
|
-
rotation_matrix=
|
735
|
-
out=
|
736
|
-
out_mask=
|
555
|
+
rotation_matrix=matrix,
|
556
|
+
out=template_rot,
|
557
|
+
out_mask=template_mask_rot,
|
737
558
|
use_geometric_center=True,
|
738
559
|
order=interpolation_order,
|
739
560
|
cache=True,
|
740
561
|
)
|
741
562
|
|
742
|
-
n_obs = be.sum(
|
743
|
-
|
744
|
-
|
563
|
+
n_obs = be.sum(template_mask_rot)
|
564
|
+
template_rot = tmpl_filter_func(template_rot, ft_temp, template_filter)
|
565
|
+
template_rot = normalize_template(template_rot, template_mask_rot, n_obs)
|
745
566
|
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
567
|
+
arr = be.fill(arr, 0)
|
568
|
+
temp = be.fill(temp, 0)
|
569
|
+
arr[unpadded_slice] = template_rot
|
570
|
+
temp[unpadded_slice] = template_mask_rot
|
571
|
+
|
572
|
+
ft_temp = be.rfftn(temp, out=ft_temp, s=fast_shape)
|
573
|
+
temp = _correlate_fts(ft_target, ft_temp, ft_denom, temp, fast_shape)
|
574
|
+
temp2 = _correlate_fts(ft_target2, ft_temp, ft_denom, temp, fast_shape)
|
751
575
|
|
752
|
-
ft_temp = rfftn(arr, ft_temp)
|
753
|
-
|
754
|
-
arr = irfftn(ft_temp, arr)
|
576
|
+
ft_temp = be.rfftn(arr, out=ft_temp, s=fast_shape)
|
577
|
+
arr = _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape)
|
755
578
|
|
756
579
|
arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
|
757
|
-
|
580
|
+
arr = norm_mask(arr, score_mask, out=arr)
|
581
|
+
|
582
|
+
callback(arr, rotation_matrix=rotation)
|
758
583
|
|
759
584
|
return callback
|
760
585
|
|
@@ -772,6 +597,7 @@ def mcc_scoring(
|
|
772
597
|
callback: CallbackClass,
|
773
598
|
interpolation_order: int,
|
774
599
|
overlap_ratio: float = 0.3,
|
600
|
+
score_mask: shm_type = None,
|
775
601
|
) -> CallbackClass:
|
776
602
|
"""
|
777
603
|
Computes a normalized cross-correlation score between ``target`` (f),
|
@@ -821,6 +647,10 @@ def mcc_scoring(
|
|
821
647
|
overlap_ratio : float, optional
|
822
648
|
Required fractional mask overlap, 0.3 by default.
|
823
649
|
|
650
|
+
Returns
|
651
|
+
-------
|
652
|
+
CallbackClass
|
653
|
+
|
824
654
|
References
|
825
655
|
----------
|
826
656
|
.. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
|
@@ -846,30 +676,12 @@ def mcc_scoring(
|
|
846
676
|
temp3 = be.zeros(fast_shape, float_dtype)
|
847
677
|
temp_ft = be.zeros(fast_ft_shape, complex_dtype)
|
848
678
|
|
849
|
-
_fftargs = {
|
850
|
-
"real_dtype": be._float_dtype,
|
851
|
-
"cmpl_dtype": be._complex_dtype,
|
852
|
-
"inv_output_shape": fast_shape,
|
853
|
-
"fwd_axes": None,
|
854
|
-
"inv_axes": None,
|
855
|
-
"inv_shape": fast_ft_shape,
|
856
|
-
"temp_fwd": temp,
|
857
|
-
}
|
858
|
-
|
859
|
-
_fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
|
860
|
-
rfftn, irfftn = be.build_fft(temp_inv=temp_ft, **_fftargs)
|
861
|
-
_ = _fftargs.pop("temp_fwd", None)
|
862
|
-
|
863
679
|
template_filter_func = _create_filter_func(
|
864
680
|
arr_shape=template.shape,
|
865
|
-
|
866
|
-
|
867
|
-
rfftn=rfftn,
|
868
|
-
irfftn=irfftn,
|
869
|
-
**_fftargs,
|
681
|
+
filter_shape=template_filter.shape,
|
682
|
+
arr_padded=True,
|
870
683
|
)
|
871
684
|
|
872
|
-
callback_func = conditional_execute(callback, callback is not None)
|
873
685
|
for index in range(rotations.shape[0]):
|
874
686
|
rotation = rotations[index]
|
875
687
|
template_rot = be.fill(template_rot, 0)
|
@@ -888,17 +700,19 @@ def mcc_scoring(
|
|
888
700
|
template_filter_func(template_rot, temp_ft, template_filter)
|
889
701
|
normalize_template(template_rot, temp, be.sum(temp))
|
890
702
|
|
891
|
-
temp_ft = rfftn(template_rot, temp_ft)
|
892
|
-
temp2 = irfftn(target_mask_ft * temp_ft, temp2)
|
893
|
-
numerator = irfftn(target_ft * temp_ft, numerator)
|
703
|
+
temp_ft = be.rfftn(template_rot, out=temp_ft, s=fast_shape)
|
704
|
+
temp2 = be.irfftn(target_mask_ft * temp_ft, out=temp2, s=fast_shape)
|
705
|
+
numerator = be.irfftn(target_ft * temp_ft, out=numerator, s=fast_shape)
|
894
706
|
|
895
707
|
# temp template_mask_rot | temp_ft template_mask_rot_ft
|
896
708
|
# Calculate overlap of masks at every point in the convolution.
|
897
709
|
# Locations with high overlap should not be taken into account.
|
898
|
-
temp_ft = rfftn(temp, temp_ft)
|
899
|
-
mask_overlap = irfftn(
|
710
|
+
temp_ft = be.rfftn(temp, out=temp_ft, s=fast_shape)
|
711
|
+
mask_overlap = be.irfftn(
|
712
|
+
temp_ft * target_mask_ft, out=mask_overlap, s=fast_shape
|
713
|
+
)
|
900
714
|
be.maximum(mask_overlap, eps, out=mask_overlap)
|
901
|
-
temp = irfftn(temp_ft * target_ft, temp)
|
715
|
+
temp = be.irfftn(temp_ft * target_ft, out=temp, s=fast_shape)
|
902
716
|
|
903
717
|
be.subtract(
|
904
718
|
numerator,
|
@@ -908,21 +722,21 @@ def mcc_scoring(
|
|
908
722
|
|
909
723
|
# temp_3 = fixed_denom
|
910
724
|
be.multiply(temp_ft, target_ft2, out=temp_ft)
|
911
|
-
temp3 = irfftn(temp_ft, temp3)
|
725
|
+
temp3 = be.irfftn(temp_ft, out=temp3, s=fast_shape)
|
912
726
|
be.subtract(temp3, be.divide(be.square(temp), mask_overlap), out=temp3)
|
913
727
|
be.maximum(temp3, 0.0, out=temp3)
|
914
728
|
|
915
729
|
# temp = moving_denom
|
916
|
-
temp_ft = rfftn(be.square(template_rot), temp_ft)
|
730
|
+
temp_ft = be.rfftn(be.square(template_rot), out=temp_ft, s=fast_shape)
|
917
731
|
be.multiply(target_mask_ft, temp_ft, out=temp_ft)
|
918
|
-
temp = irfftn(temp_ft, temp)
|
732
|
+
temp = be.irfftn(temp_ft, out=temp, s=fast_shape)
|
919
733
|
|
920
734
|
be.subtract(temp, be.divide(be.square(temp2), mask_overlap), out=temp)
|
921
735
|
be.maximum(temp, 0.0, out=temp)
|
922
736
|
|
923
737
|
# temp_2 = denom
|
924
738
|
be.multiply(temp3, temp, out=temp)
|
925
|
-
be.sqrt(temp, temp2)
|
739
|
+
be.sqrt(temp, out=temp2)
|
926
740
|
|
927
741
|
# Pixels where `denom` is very small will introduce large
|
928
742
|
# numbers after division. To get around this problem,
|
@@ -938,29 +752,11 @@ def mcc_scoring(
|
|
938
752
|
mask_overlap, axis=axes, keepdims=True
|
939
753
|
)
|
940
754
|
temp[mask_overlap < number_px_threshold] = 0.0
|
941
|
-
|
755
|
+
callback(temp, rotation_matrix=rotation)
|
942
756
|
|
943
757
|
return callback
|
944
758
|
|
945
759
|
|
946
|
-
def _format_slice(shape, squeeze_axis):
|
947
|
-
ret = tuple(
|
948
|
-
slice(None) if i not in squeeze_axis else 0 for i, _ in enumerate(shape)
|
949
|
-
)
|
950
|
-
return ret
|
951
|
-
|
952
|
-
|
953
|
-
def _get_batch_dim(target, template):
|
954
|
-
target_batch, template_batch = [], []
|
955
|
-
for i in range(len(target.shape)):
|
956
|
-
if target.shape[i] == 1 and template.shape[i] != 1:
|
957
|
-
template_batch.append(i)
|
958
|
-
if target.shape[i] != 1 and template.shape[i] == 1:
|
959
|
-
target_batch.append(i)
|
960
|
-
|
961
|
-
return target_batch, template_batch
|
962
|
-
|
963
|
-
|
964
760
|
def flc_scoring2(
|
965
761
|
template: shm_type,
|
966
762
|
template_mask: shm_type,
|
@@ -972,26 +768,25 @@ def flc_scoring2(
|
|
972
768
|
rotations: BackendArray,
|
973
769
|
callback: CallbackClass,
|
974
770
|
interpolation_order: int,
|
771
|
+
score_mask: shm_type = None,
|
975
772
|
) -> CallbackClass:
|
976
|
-
callback_func = conditional_execute(callback, callback is not None)
|
977
|
-
|
978
|
-
# Retrieve objects from shared memory
|
979
773
|
template = be.from_sharedarr(template)
|
980
774
|
template_mask = be.from_sharedarr(template_mask)
|
981
775
|
ft_target = be.from_sharedarr(ft_target)
|
982
776
|
ft_target2 = be.from_sharedarr(ft_target2)
|
983
777
|
template_filter = be.from_sharedarr(template_filter)
|
778
|
+
score_mask = be.from_sharedarr(score_mask)
|
984
779
|
|
985
|
-
|
986
|
-
target_batch, template_batch = _get_batch_dim(ft_target, template)
|
987
|
-
sqz_cmpl = tuple(1 if i in target_batch else x for i, x in enumerate(fast_ft_shape))
|
988
|
-
sqz_slice = tuple(slice(0, 1) if x == 1 else slice(None) for x in sqz_cmpl)
|
780
|
+
tar_batch, tmpl_batch = _get_batch_dim(ft_target, template)
|
989
781
|
|
990
|
-
|
991
|
-
if
|
992
|
-
|
993
|
-
|
994
|
-
|
782
|
+
nd = len(fast_shape)
|
783
|
+
sqz_slice = tuple(slice(0, 1) if i in tar_batch else slice(None) for i in range(nd))
|
784
|
+
tmpl_subset = tuple(0 if i in tar_batch else slice(None) for i in range(nd))
|
785
|
+
|
786
|
+
axes, shape, batched = None, fast_shape, len(tmpl_batch) > 0
|
787
|
+
if len(tar_batch) or len(tmpl_batch):
|
788
|
+
axes = tuple(i for i in range(nd) if i not in (*tar_batch, *tmpl_batch))
|
789
|
+
shape = tuple(fast_shape[i] for i in axes)
|
995
790
|
|
996
791
|
arr = be.zeros(fast_shape, be._float_dtype)
|
997
792
|
temp = be.zeros(fast_shape, be._float_dtype)
|
@@ -999,68 +794,47 @@ def flc_scoring2(
|
|
999
794
|
ft_denom = be.zeros(fast_ft_shape, be._complex_dtype)
|
1000
795
|
|
1001
796
|
tmp_sqz, arr_sqz, ft_temp = temp[sqz_slice], arr[sqz_slice], ft_denom[sqz_slice]
|
1002
|
-
if be.size(template_filter) != 1:
|
1003
|
-
ret_shape = np.broadcast_shapes(
|
1004
|
-
sqz_cmpl, tuple(int(x) for x in template_filter.shape)
|
1005
|
-
)
|
1006
|
-
ft_temp = be.zeros(ret_shape, be._complex_dtype)
|
1007
|
-
|
1008
|
-
_fftargs = {
|
1009
|
-
"real_dtype": be._float_dtype,
|
1010
|
-
"cmpl_dtype": be._complex_dtype,
|
1011
|
-
"inv_output_shape": fast_shape,
|
1012
|
-
"fwd_axes": data_axes,
|
1013
|
-
"inv_axes": data_axes,
|
1014
|
-
"inv_shape": fast_ft_shape,
|
1015
|
-
"temp_fwd": arr_sqz if _shape_match(ft_temp.shape, sqz_cmpl) else arr,
|
1016
|
-
}
|
1017
|
-
|
1018
|
-
# build_fft ignores fwd_shape if temp_fwd is given and serves only for bookkeeping
|
1019
|
-
_fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
|
1020
|
-
rfftn, irfftn = be.build_fft(temp_inv=ft_denom, **_fftargs)
|
1021
|
-
_ = _fftargs.pop("temp_fwd", None)
|
1022
797
|
|
1023
798
|
template_filter_func = _create_filter_func(
|
1024
799
|
arr_shape=template.shape,
|
1025
|
-
|
1026
|
-
|
1027
|
-
rfftn=rfftn,
|
1028
|
-
irfftn=irfftn,
|
1029
|
-
**_fftargs,
|
800
|
+
filter_shape=template_filter.shape,
|
801
|
+
arr_padded=True,
|
1030
802
|
)
|
803
|
+
norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
|
1031
804
|
|
1032
805
|
eps = be.eps(be._float_dtype)
|
1033
806
|
for index in range(rotations.shape[0]):
|
1034
807
|
rotation = rotations[index]
|
1035
808
|
be.fill(arr, 0)
|
1036
809
|
be.fill(temp, 0)
|
1037
|
-
|
1038
|
-
|
1039
|
-
|
810
|
+
|
811
|
+
_, _ = be.rigid_transform(
|
812
|
+
arr=template[tmpl_subset],
|
813
|
+
arr_mask=template_mask[tmpl_subset],
|
1040
814
|
rotation_matrix=rotation,
|
1041
|
-
out=arr_sqz,
|
1042
|
-
out_mask=tmp_sqz,
|
815
|
+
out=arr_sqz[tmpl_subset],
|
816
|
+
out_mask=tmp_sqz[tmpl_subset],
|
1043
817
|
use_geometric_center=True,
|
1044
818
|
order=interpolation_order,
|
1045
|
-
cache=
|
1046
|
-
batched=
|
819
|
+
cache=False,
|
820
|
+
batched=batched,
|
1047
821
|
)
|
1048
|
-
|
822
|
+
|
823
|
+
n_obs = be.sum(tmp_sqz, axis=axes, keepdims=True)
|
1049
824
|
arr_norm = template_filter_func(arr_sqz, ft_temp, template_filter)
|
1050
|
-
arr_norm = normalize_template(arr_norm, tmp_sqz, n_obs, axis=
|
825
|
+
arr_norm = normalize_template(arr_norm, tmp_sqz, n_obs, axis=axes)
|
826
|
+
|
827
|
+
ft_temp = be.rfftn(tmp_sqz, out=ft_temp, axes=axes, s=shape)
|
828
|
+
temp = _correlate_fts(ft_target, ft_temp, ft_denom, temp, shape, axes)
|
829
|
+
temp2 = _correlate_fts(ft_target2, ft_temp, ft_denom, temp2, shape, axes)
|
1051
830
|
|
1052
|
-
ft_temp = be.rfftn(
|
1053
|
-
|
1054
|
-
temp = be.irfftn(ft_denom, temp, axes=data_axes, s=data_shape)
|
1055
|
-
ft_denom = be.multiply(ft_target2, ft_temp, out=ft_denom)
|
1056
|
-
temp2 = be.irfftn(ft_denom, temp2, axes=data_axes, s=data_shape)
|
831
|
+
ft_temp = be.rfftn(arr_norm, out=ft_temp, axes=axes, s=shape)
|
832
|
+
arr = _correlate_fts(ft_target, ft_temp, ft_denom, arr, shape, axes)
|
1057
833
|
|
1058
|
-
|
1059
|
-
|
1060
|
-
arr = irfftn(ft_denom, arr)
|
834
|
+
arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
|
835
|
+
arr = norm_mask(arr, score_mask, out=arr)
|
1061
836
|
|
1062
|
-
|
1063
|
-
callback_func(arr, rotation_matrix=rotation)
|
837
|
+
callback(arr, rotation_matrix=rotation)
|
1064
838
|
|
1065
839
|
return callback
|
1066
840
|
|
@@ -1078,108 +852,152 @@ def corr_scoring2(
|
|
1078
852
|
interpolation_order: int,
|
1079
853
|
target_filter: shm_type = None,
|
1080
854
|
template_mask: shm_type = None,
|
855
|
+
score_mask: shm_type = None,
|
1081
856
|
) -> CallbackClass:
|
1082
857
|
template = be.from_sharedarr(template)
|
1083
858
|
ft_target = be.from_sharedarr(ft_target)
|
1084
859
|
inv_denominator = be.from_sharedarr(inv_denominator)
|
1085
860
|
numerator = be.from_sharedarr(numerator)
|
1086
861
|
template_filter = be.from_sharedarr(template_filter)
|
862
|
+
score_mask = be.from_sharedarr(score_mask)
|
1087
863
|
|
1088
|
-
|
1089
|
-
|
1090
|
-
|
1091
|
-
sqz_slice = tuple(slice(0, 1) if
|
1092
|
-
|
1093
|
-
|
1094
|
-
|
1095
|
-
|
1096
|
-
|
1097
|
-
|
1098
|
-
|
1099
|
-
|
864
|
+
tar_batch, tmpl_batch = _get_batch_dim(ft_target, template)
|
865
|
+
|
866
|
+
nd = len(fast_shape)
|
867
|
+
sqz_slice = tuple(slice(0, 1) if i in tar_batch else slice(None) for i in range(nd))
|
868
|
+
tmpl_subset = tuple(0 if i in tar_batch else slice(None) for i in range(nd))
|
869
|
+
|
870
|
+
axes, shape, batched = None, fast_shape, len(tmpl_batch) > 0
|
871
|
+
if len(tar_batch) or len(tmpl_batch):
|
872
|
+
axes = tuple(i for i in range(nd) if i not in (*tar_batch, *tmpl_batch))
|
873
|
+
shape = tuple(fast_shape[i] for i in axes)
|
874
|
+
|
875
|
+
unpadded_slice = tuple(
|
876
|
+
slice(None) if i in (*tar_batch, *tmpl_batch) else slice(0, x)
|
877
|
+
for i, x in enumerate(template.shape)
|
878
|
+
)
|
1100
879
|
|
1101
880
|
arr = be.zeros(fast_shape, be._float_dtype)
|
1102
881
|
ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
|
1103
882
|
arr_sqz, ft_sqz = arr[sqz_slice], ft_temp[sqz_slice]
|
1104
883
|
|
1105
|
-
|
1106
|
-
# The filter could be w.r.t the unpadded template
|
1107
|
-
ret_shape = tuple(
|
1108
|
-
int(x * y) if x == 1 or y == 1 else y
|
1109
|
-
for x, y in zip(sqz_cmpl, template_filter.shape)
|
1110
|
-
)
|
1111
|
-
ft_sqz = be.zeros(ret_shape, be._complex_dtype)
|
1112
|
-
|
1113
|
-
norm_func, norm_template, mask_sum = normalize_template, False, 1
|
884
|
+
n_obs = None
|
1114
885
|
if template_mask is not None:
|
1115
886
|
template_mask = be.from_sharedarr(template_mask)
|
1116
|
-
|
1117
|
-
be.astype(template_mask, be._overflow_safe_dtype),
|
1118
|
-
axis=data_axes,
|
1119
|
-
keepdims=True,
|
1120
|
-
)
|
1121
|
-
if be.datatype_bytes(template_mask.dtype) == 2:
|
1122
|
-
norm_func = _normalize_template_overflow_safe
|
1123
|
-
|
1124
|
-
callback_func = conditional_execute(callback, callback is not None)
|
1125
|
-
norm_template = conditional_execute(norm_func, norm_template)
|
1126
|
-
norm_numerator = conditional_execute(
|
1127
|
-
be.subtract, identity, _shape_match(numerator.shape, fast_shape)
|
1128
|
-
)
|
1129
|
-
norm_denominator = conditional_execute(
|
1130
|
-
be.multiply, identity, _shape_match(inv_denominator.shape, fast_shape)
|
1131
|
-
)
|
887
|
+
n_obs = be.sum(template_mask, axis=axes, keepdims=True)
|
1132
888
|
|
1133
|
-
|
1134
|
-
|
1135
|
-
|
1136
|
-
|
1137
|
-
"inv_axes": data_axes,
|
1138
|
-
"inv_shape": fast_ft_shape,
|
1139
|
-
"inv_output_shape": fast_shape,
|
1140
|
-
"temp_fwd": arr_sqz if _shape_match(ft_sqz.shape, sqz_cmpl) else arr,
|
1141
|
-
}
|
1142
|
-
|
1143
|
-
# build_fft ignores fwd_shape if temp_fwd is given and serves only for bookkeeping
|
1144
|
-
_fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
|
1145
|
-
rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
|
1146
|
-
_ = _fftargs.pop("temp_fwd", None)
|
889
|
+
norm_template = conditional_execute(normalize_template, n_obs is not None)
|
890
|
+
norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
|
891
|
+
norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1,))
|
892
|
+
norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
|
1147
893
|
|
1148
894
|
template_filter_func = _create_filter_func(
|
1149
895
|
arr_shape=template.shape,
|
1150
|
-
|
1151
|
-
|
1152
|
-
rfftn=rfftn,
|
1153
|
-
irfftn=irfftn,
|
1154
|
-
**_fftargs,
|
896
|
+
filter_shape=template_filter.shape,
|
897
|
+
arr_padded=True,
|
1155
898
|
)
|
1156
899
|
|
1157
900
|
for index in range(rotations.shape[0]):
|
1158
901
|
be.fill(arr, 0)
|
1159
902
|
rotation = rotations[index]
|
1160
|
-
|
1161
|
-
arr=template,
|
903
|
+
_, _ = be.rigid_transform(
|
904
|
+
arr=template[tmpl_subset],
|
1162
905
|
rotation_matrix=rotation,
|
1163
|
-
out=arr_sqz,
|
906
|
+
out=arr_sqz[tmpl_subset],
|
1164
907
|
use_geometric_center=True,
|
1165
908
|
order=interpolation_order,
|
1166
|
-
cache=
|
1167
|
-
batched=
|
909
|
+
cache=False,
|
910
|
+
batched=batched,
|
1168
911
|
)
|
1169
912
|
arr_norm = template_filter_func(arr_sqz, ft_sqz, template_filter)
|
1170
|
-
norm_template(arr_norm[unpadded_slice], template_mask,
|
913
|
+
norm_template(arr_norm[unpadded_slice], template_mask, n_obs, axis=axes)
|
1171
914
|
|
1172
|
-
ft_sqz = rfftn(arr_norm, ft_sqz)
|
1173
|
-
|
1174
|
-
arr = irfftn(ft_temp, arr)
|
915
|
+
ft_sqz = be.rfftn(arr_norm, out=ft_sqz, axes=axes, s=shape)
|
916
|
+
arr = _correlate_fts(ft_target, ft_sqz, ft_temp, arr, shape, axes)
|
1175
917
|
|
1176
|
-
arr =
|
1177
|
-
arr =
|
1178
|
-
|
918
|
+
arr = norm_sub(arr, numerator, out=arr)
|
919
|
+
arr = norm_mul(arr, inv_denominator, out=arr)
|
920
|
+
arr = norm_mask(arr, score_mask, out=arr)
|
921
|
+
|
922
|
+
callback(arr, rotation_matrix=rotation)
|
1179
923
|
|
1180
924
|
return callback
|
1181
925
|
|
1182
926
|
|
927
|
+
def _get_batch_dim(target, template):
|
928
|
+
target_batch, template_batch = [], []
|
929
|
+
for i in range(len(target.shape)):
|
930
|
+
if target.shape[i] == 1 and template.shape[i] != 1:
|
931
|
+
template_batch.append(i)
|
932
|
+
if target.shape[i] != 1 and template.shape[i] == 1:
|
933
|
+
target_batch.append(i)
|
934
|
+
|
935
|
+
return target_batch, template_batch
|
936
|
+
|
937
|
+
|
938
|
+
def _correlate_fts(ft_tar, ft_tmpl, ft_buffer, real_buffer, fast_shape, axes=None):
|
939
|
+
ft_buffer = be.multiply(ft_tar, ft_tmpl, out=ft_buffer)
|
940
|
+
return be.irfftn(ft_buffer, out=real_buffer, s=fast_shape, axes=axes)
|
941
|
+
|
942
|
+
|
943
|
+
def _create_filter_func(
|
944
|
+
arr_shape: Tuple[int],
|
945
|
+
filter_shape: BackendArray,
|
946
|
+
arr_padded: bool = False,
|
947
|
+
axes=None,
|
948
|
+
) -> Callable:
|
949
|
+
"""
|
950
|
+
Configure template filtering function for Fourier transforms.
|
951
|
+
|
952
|
+
Conceptually we distinguish between three cases. The base case
|
953
|
+
is that both template and the corresponding filter have the same
|
954
|
+
shape. Padding is used when the template filter is larger than
|
955
|
+
the template, for instance to better resolve Fourier filters. Finally
|
956
|
+
this function also handles the case when a filter is supposed to be
|
957
|
+
broadcasted over the template batch dimension.
|
958
|
+
|
959
|
+
Parameters
|
960
|
+
----------
|
961
|
+
arr_shape : tuple of ints
|
962
|
+
Shape of the array to be filtered.
|
963
|
+
filter_shape : BackendArray
|
964
|
+
Precomputed filter to apply in the frequency domain.
|
965
|
+
arr_padded : bool, optional
|
966
|
+
Whether the input template is padded and will need to be cropped
|
967
|
+
to arr_shape prior to filter applications. Defaults to False.
|
968
|
+
axes : tuple of ints, optional
|
969
|
+
Axes to perform Fourier transform over.
|
970
|
+
|
971
|
+
Returns
|
972
|
+
-------
|
973
|
+
Callable
|
974
|
+
Filter function with parameters template, ft_temp and template_filter.
|
975
|
+
"""
|
976
|
+
if filter_shape == (1,):
|
977
|
+
return conditional_execute(identity, execute_operation=True)
|
978
|
+
|
979
|
+
# Default case, all shapes are correctly matched
|
980
|
+
def _apply_filter(template, ft_temp, template_filter):
|
981
|
+
ft_temp = be.rfftn(template, out=ft_temp, s=template.shape)
|
982
|
+
ft_temp = be.multiply(ft_temp, template_filter, out=ft_temp)
|
983
|
+
return be.irfftn(ft_temp, out=template, s=template.shape)
|
984
|
+
|
985
|
+
if not arr_padded:
|
986
|
+
return _apply_filter
|
987
|
+
|
988
|
+
# Array is padded but filter is w.r.t to the original template
|
989
|
+
real_subset = tuple(slice(0, x) for x in arr_shape)
|
990
|
+
_template = be.zeros(arr_shape, be._float_dtype)
|
991
|
+
_ft_temp = be.zeros(filter_shape, be._complex_dtype)
|
992
|
+
|
993
|
+
def _apply_filter_subset(template, ft_temp, template_filter):
|
994
|
+
_template[:] = template[real_subset]
|
995
|
+
template[real_subset] = _apply_filter(_template, _ft_temp, template_filter)
|
996
|
+
return template
|
997
|
+
|
998
|
+
return _apply_filter_subset
|
999
|
+
|
1000
|
+
|
1183
1001
|
MATCHING_EXHAUSTIVE_REGISTER = {
|
1184
1002
|
"CC": (cc_setup, corr_scoring),
|
1185
1003
|
"LCC": (lcc_setup, corr_scoring),
|
@@ -1188,6 +1006,6 @@ MATCHING_EXHAUSTIVE_REGISTER = {
|
|
1188
1006
|
"FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
|
1189
1007
|
"FLC": (flc_setup, flc_scoring),
|
1190
1008
|
"MCC": (mcc_setup, mcc_scoring),
|
1191
|
-
"
|
1009
|
+
"batchFLCSphericalMask": (flcSphericalMask_setup, corr_scoring2),
|
1192
1010
|
"batchFLC": (flc_setup, flc_scoring2),
|
1193
1011
|
}
|