pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/match_template.py +473 -140
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/METADATA +2 -2
- pytme-0.2.1.dist-info/RECORD +73 -0
- scripts/extract_candidates.py +117 -85
- scripts/match_template.py +473 -140
- scripts/match_template_filters.py +458 -169
- scripts/postprocess.py +107 -49
- scripts/preprocessor_gui.py +4 -1
- scripts/refine_matches.py +364 -160
- tme/__version__.py +1 -1
- tme/analyzer.py +278 -148
- tme/backends/__init__.py +1 -0
- tme/backends/cupy_backend.py +20 -13
- tme/backends/jax_backend.py +218 -0
- tme/backends/matching_backend.py +25 -10
- tme/backends/mlx_backend.py +13 -9
- tme/backends/npfftw_backend.py +22 -12
- tme/backends/pytorch_backend.py +20 -9
- tme/density.py +85 -64
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +86 -60
- tme/matching_exhaustive.py +245 -166
- tme/matching_optimization.py +137 -69
- tme/matching_utils.py +1 -1
- tme/orientations.py +175 -55
- tme/preprocessing/__init__.py +2 -0
- tme/preprocessing/_utils.py +188 -0
- tme/preprocessing/composable_filter.py +31 -0
- tme/preprocessing/compose.py +51 -0
- tme/preprocessing/frequency_filters.py +378 -0
- tme/preprocessing/tilt_series.py +1017 -0
- tme/preprocessor.py +17 -7
- tme/structure.py +4 -1
- pytme-0.2.0b0.dist-info/RECORD +0 -66
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
tme/matching_exhaustive.py
CHANGED
@@ -7,7 +7,6 @@
|
|
7
7
|
import os
|
8
8
|
import sys
|
9
9
|
import warnings
|
10
|
-
from copy import deepcopy
|
11
10
|
from itertools import product
|
12
11
|
from typing import Callable, Tuple, Dict
|
13
12
|
from functools import wraps
|
@@ -19,14 +18,12 @@ from scipy.ndimage import laplace
|
|
19
18
|
|
20
19
|
from .analyzer import MaxScoreOverRotations
|
21
20
|
from .matching_utils import (
|
22
|
-
apply_convolution_mode,
|
23
21
|
handle_traceback,
|
24
22
|
split_numpy_array_slices,
|
25
23
|
conditional_execute,
|
26
24
|
)
|
27
25
|
from .matching_memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
|
28
|
-
|
29
|
-
from .preprocessor import Preprocessor
|
26
|
+
from .preprocessing import Compose
|
30
27
|
from .matching_data import MatchingData
|
31
28
|
from .backends import backend
|
32
29
|
from .types import NDArray, CallbackClass
|
@@ -64,9 +61,6 @@ def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> No
|
|
64
61
|
----------
|
65
62
|
.. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
66
63
|
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
67
|
-
.. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
|
68
|
-
R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
|
69
|
-
13375 (2023)
|
70
64
|
|
71
65
|
Returns
|
72
66
|
-------
|
@@ -87,6 +81,16 @@ def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> No
|
|
87
81
|
return None
|
88
82
|
|
89
83
|
|
84
|
+
def _normalize_under_mask_overflow_safe(
|
85
|
+
template: NDArray, mask: NDArray, mask_intensity
|
86
|
+
) -> None:
|
87
|
+
_template = backend.astype(template, backend._overflow_safe_dtype)
|
88
|
+
_mask = backend.astype(mask, backend._overflow_safe_dtype)
|
89
|
+
normalize_under_mask(template=_template, mask=_mask, mask_intensity=mask_intensity)
|
90
|
+
template[:] = backend.astype(_template, template.dtype)
|
91
|
+
return None
|
92
|
+
|
93
|
+
|
90
94
|
def apply_filter(ft_template, template_filter):
|
91
95
|
# This is an approximation to applying the mask, irfftn, normalize, rfftn
|
92
96
|
std_before = backend.std(ft_template)
|
@@ -103,8 +107,6 @@ def cc_setup(
|
|
103
107
|
target: NDArray,
|
104
108
|
fast_shape: Tuple[int],
|
105
109
|
fast_ft_shape: Tuple[int],
|
106
|
-
real_dtype: type,
|
107
|
-
complex_dtype: type,
|
108
110
|
shared_memory_handler: Callable,
|
109
111
|
callback_class: Callable,
|
110
112
|
callback_class_args: Dict,
|
@@ -124,7 +126,7 @@ def cc_setup(
|
|
124
126
|
:py:meth:`corr_scoring`
|
125
127
|
:py:class:`tme.matching_optimization.CrossCorrelation`
|
126
128
|
"""
|
127
|
-
|
129
|
+
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
128
130
|
target_pad = backend.topleft_pad(target, fast_shape)
|
129
131
|
target_pad_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
130
132
|
|
@@ -141,28 +143,26 @@ def cc_setup(
|
|
141
143
|
arr=backend.preallocate_array(1, real_dtype) + 1,
|
142
144
|
shared_memory_handler=shared_memory_handler,
|
143
145
|
)
|
144
|
-
|
146
|
+
numerator_buffer = backend.arr_to_sharedarr(
|
145
147
|
arr=backend.preallocate_array(1, real_dtype),
|
146
148
|
shared_memory_handler=shared_memory_handler,
|
147
149
|
)
|
148
150
|
|
149
151
|
target_ft_tuple = (target_ft_out, fast_ft_shape, complex_dtype)
|
150
|
-
template_tuple = (template_out, template.shape,
|
152
|
+
template_tuple = (template_out, template.shape, template.dtype)
|
151
153
|
|
152
154
|
inv_denominator_tuple = (inv_denominator_buffer, (1,), real_dtype)
|
153
|
-
|
155
|
+
numerator_tuple = (numerator_buffer, (1,), real_dtype)
|
154
156
|
|
155
157
|
ret = {
|
156
158
|
"template": template_tuple,
|
157
159
|
"ft_target": target_ft_tuple,
|
158
160
|
"inv_denominator": inv_denominator_tuple,
|
159
|
-
"
|
160
|
-
"targetshape":
|
161
|
+
"numerator": numerator_tuple,
|
162
|
+
"targetshape": target.shape,
|
161
163
|
"templateshape": template.shape,
|
162
164
|
"fast_shape": fast_shape,
|
163
165
|
"fast_ft_shape": fast_ft_shape,
|
164
|
-
"real_dtype": real_dtype,
|
165
|
-
"complex_dtype": complex_dtype,
|
166
166
|
"callback_class": callback_class,
|
167
167
|
"callback_class_args": callback_class_args,
|
168
168
|
"use_memmap": kwargs.get("use_memmap", False),
|
@@ -200,8 +200,6 @@ def corr_setup(
|
|
200
200
|
target: NDArray,
|
201
201
|
fast_shape: Tuple[int],
|
202
202
|
fast_ft_shape: Tuple[int],
|
203
|
-
real_dtype: type,
|
204
|
-
complex_dtype: type,
|
205
203
|
shared_memory_handler: Callable,
|
206
204
|
callback_class: Callable,
|
207
205
|
callback_class_args: Dict,
|
@@ -233,6 +231,7 @@ def corr_setup(
|
|
233
231
|
:py:meth:`corr_scoring`
|
234
232
|
:py:class:`tme.matching_optimization.NormalizedCrossCorrelation`.
|
235
233
|
"""
|
234
|
+
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
236
235
|
target_pad = backend.topleft_pad(target, fast_shape)
|
237
236
|
|
238
237
|
# The exact composition of the denominator is debatable
|
@@ -266,14 +265,14 @@ def corr_setup(
|
|
266
265
|
template_mean = backend.divide(template_mean, n_observations)
|
267
266
|
template_ssd = backend.sum(
|
268
267
|
backend.square(
|
269
|
-
backend.multiply(backend.
|
268
|
+
backend.multiply(backend.subtract(template, template_mean), template_mask)
|
270
269
|
)
|
271
270
|
)
|
272
|
-
template_volume = np.prod(template.shape)
|
271
|
+
template_volume = np.prod(tuple(int(x) for x in template.shape))
|
273
272
|
backend.multiply(template, template_mask, out=template)
|
274
273
|
|
275
|
-
# Final numerator is score -
|
276
|
-
|
274
|
+
# Final numerator is score - numerator
|
275
|
+
numerator = backend.multiply(target_window_sum, template_mean)
|
277
276
|
|
278
277
|
# Compute denominator
|
279
278
|
backend.multiply(target_window_sum, target_window_sum, out=target_window_sum)
|
@@ -300,29 +299,27 @@ def corr_setup(
|
|
300
299
|
inv_denominator_buffer = backend.arr_to_sharedarr(
|
301
300
|
arr=inv_denominator, shared_memory_handler=shared_memory_handler
|
302
301
|
)
|
303
|
-
|
304
|
-
arr=
|
302
|
+
numerator_buffer = backend.arr_to_sharedarr(
|
303
|
+
arr=numerator, shared_memory_handler=shared_memory_handler
|
305
304
|
)
|
306
305
|
|
307
|
-
template_tuple = (template_buffer,
|
306
|
+
template_tuple = (template_buffer, template.shape, template.dtype)
|
308
307
|
target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
|
309
308
|
|
310
309
|
inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
|
311
|
-
|
310
|
+
numerator_tuple = (numerator_buffer, fast_shape, real_dtype)
|
312
311
|
|
313
|
-
ft_target, inv_denominator,
|
312
|
+
ft_target, inv_denominator, numerator = None, None, None
|
314
313
|
|
315
314
|
ret = {
|
316
315
|
"template": template_tuple,
|
317
316
|
"ft_target": target_ft_tuple,
|
318
317
|
"inv_denominator": inv_denominator_tuple,
|
319
|
-
"
|
320
|
-
"targetshape":
|
321
|
-
"templateshape":
|
318
|
+
"numerator": numerator_tuple,
|
319
|
+
"targetshape": target.shape,
|
320
|
+
"templateshape": template.shape,
|
322
321
|
"fast_shape": fast_shape,
|
323
322
|
"fast_ft_shape": fast_ft_shape,
|
324
|
-
"real_dtype": real_dtype,
|
325
|
-
"complex_dtype": complex_dtype,
|
326
323
|
"callback_class": callback_class,
|
327
324
|
"callback_class_args": callback_class_args,
|
328
325
|
"template_mean": kwargs.get("template_mean", template_mean),
|
@@ -374,8 +371,6 @@ def flc_setup(
|
|
374
371
|
target: NDArray,
|
375
372
|
fast_shape: Tuple[int],
|
376
373
|
fast_ft_shape: Tuple[int],
|
377
|
-
real_dtype: type,
|
378
|
-
complex_dtype: type,
|
379
374
|
shared_memory_handler: Callable,
|
380
375
|
callback_class: Callable,
|
381
376
|
callback_class_args: Dict,
|
@@ -414,8 +409,8 @@ def flc_setup(
|
|
414
409
|
target_pad = backend.topleft_pad(target, fast_shape)
|
415
410
|
|
416
411
|
# Target and squared target window sums
|
417
|
-
ft_target = backend.preallocate_array(fast_ft_shape,
|
418
|
-
ft_target2 = backend.preallocate_array(fast_ft_shape,
|
412
|
+
ft_target = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
|
413
|
+
ft_target2 = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
|
419
414
|
rfftn(target_pad, ft_target)
|
420
415
|
backend.square(target_pad, out=target_pad)
|
421
416
|
rfftn(target_pad, ft_target2)
|
@@ -439,11 +434,15 @@ def flc_setup(
|
|
439
434
|
arr=template_mask, shared_memory_handler=shared_memory_handler
|
440
435
|
)
|
441
436
|
|
442
|
-
template_tuple = (template_buffer, template.shape,
|
443
|
-
template_mask_tuple = (
|
437
|
+
template_tuple = (template_buffer, template.shape, template.dtype)
|
438
|
+
template_mask_tuple = (
|
439
|
+
template_mask_buffer,
|
440
|
+
template_mask.shape,
|
441
|
+
template_mask.dtype,
|
442
|
+
)
|
444
443
|
|
445
|
-
target_ft_tuple = (ft_target, fast_ft_shape,
|
446
|
-
target_ft2_tuple = (ft_target2, fast_ft_shape,
|
444
|
+
target_ft_tuple = (ft_target, fast_ft_shape, backend._complex_dtype)
|
445
|
+
target_ft2_tuple = (ft_target2, fast_ft_shape, backend._complex_dtype)
|
447
446
|
|
448
447
|
ret = {
|
449
448
|
"template": template_tuple,
|
@@ -454,8 +453,6 @@ def flc_setup(
|
|
454
453
|
"templateshape": template.shape,
|
455
454
|
"fast_shape": fast_shape,
|
456
455
|
"fast_ft_shape": fast_ft_shape,
|
457
|
-
"real_dtype": real_dtype,
|
458
|
-
"complex_dtype": complex_dtype,
|
459
456
|
"callback_class": callback_class,
|
460
457
|
"callback_class_args": callback_class_args,
|
461
458
|
}
|
@@ -471,8 +468,6 @@ def flcSphericalMask_setup(
|
|
471
468
|
target: NDArray,
|
472
469
|
fast_shape: Tuple[int],
|
473
470
|
fast_ft_shape: Tuple[int],
|
474
|
-
real_dtype: type,
|
475
|
-
complex_dtype: type,
|
476
471
|
shared_memory_handler: Callable,
|
477
472
|
callback_class: Callable,
|
478
473
|
callback_class_args: Dict,
|
@@ -509,6 +504,7 @@ def flcSphericalMask_setup(
|
|
509
504
|
--------
|
510
505
|
:py:meth:`corr_scoring`
|
511
506
|
"""
|
507
|
+
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
512
508
|
target_pad = backend.topleft_pad(target, fast_shape)
|
513
509
|
|
514
510
|
# Target and squared target window sums
|
@@ -518,10 +514,14 @@ def flcSphericalMask_setup(
|
|
518
514
|
|
519
515
|
temp = backend.preallocate_array(fast_shape, real_dtype)
|
520
516
|
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
521
|
-
|
517
|
+
numerator = backend.preallocate_array(1, real_dtype)
|
522
518
|
|
523
|
-
|
524
|
-
|
519
|
+
n_observations, norm_func = backend.sum(template_mask), normalize_under_mask
|
520
|
+
if backend.datatype_bytes(template_mask.dtype) == 2:
|
521
|
+
norm_func = _normalize_under_mask_overflow_safe
|
522
|
+
n_observations = backend.sum(
|
523
|
+
backend.astype(template_mask, backend._overflow_safe_dtype)
|
524
|
+
)
|
525
525
|
|
526
526
|
template_mask_pad = backend.topleft_pad(template_mask, fast_shape)
|
527
527
|
rfftn(template_mask_pad, ft_template_mask)
|
@@ -544,15 +544,11 @@ def flcSphericalMask_setup(
|
|
544
544
|
backend.sqrt(temp, out=temp)
|
545
545
|
backend.multiply(temp, n_observations, out=temp)
|
546
546
|
|
547
|
-
tol = 1e3 * eps * backend.max(backend.abs(temp))
|
548
|
-
nonzero_indices = temp > tol
|
549
|
-
|
550
547
|
backend.fill(temp2, 0)
|
548
|
+
nonzero_indices = temp > backend.eps(real_dtype)
|
551
549
|
temp2[nonzero_indices] = 1 / temp[nonzero_indices]
|
552
550
|
|
553
|
-
|
554
|
-
template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
|
555
|
-
)
|
551
|
+
norm_func(template=template, mask=template_mask, mask_intensity=n_observations)
|
556
552
|
|
557
553
|
template_buffer = backend.arr_to_sharedarr(
|
558
554
|
arr=template, shared_memory_handler=shared_memory_handler
|
@@ -566,29 +562,27 @@ def flcSphericalMask_setup(
|
|
566
562
|
inv_denominator_buffer = backend.arr_to_sharedarr(
|
567
563
|
arr=temp2, shared_memory_handler=shared_memory_handler
|
568
564
|
)
|
569
|
-
|
570
|
-
arr=
|
565
|
+
numerator_buffer = backend.arr_to_sharedarr(
|
566
|
+
arr=numerator, shared_memory_handler=shared_memory_handler
|
571
567
|
)
|
572
568
|
|
573
|
-
template_tuple = (template_buffer, template.shape,
|
574
|
-
template_mask_tuple = (template_mask_buffer, template.shape,
|
569
|
+
template_tuple = (template_buffer, template.shape, template.dtype)
|
570
|
+
template_mask_tuple = (template_mask_buffer, template.shape, template_mask.dtype)
|
575
571
|
target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
|
576
572
|
|
577
573
|
inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
|
578
|
-
|
574
|
+
numerator_tuple = (numerator_buffer, (1,), real_dtype)
|
579
575
|
|
580
576
|
ret = {
|
581
577
|
"template": template_tuple,
|
582
578
|
"template_mask": template_mask_tuple,
|
583
579
|
"ft_target": target_ft_tuple,
|
584
580
|
"inv_denominator": inv_denominator_tuple,
|
585
|
-
"
|
581
|
+
"numerator": numerator_tuple,
|
586
582
|
"targetshape": target.shape,
|
587
583
|
"templateshape": template.shape,
|
588
584
|
"fast_shape": fast_shape,
|
589
585
|
"fast_ft_shape": fast_ft_shape,
|
590
|
-
"real_dtype": real_dtype,
|
591
|
-
"complex_dtype": complex_dtype,
|
592
586
|
"callback_class": callback_class,
|
593
587
|
"callback_class_args": callback_class_args,
|
594
588
|
}
|
@@ -605,8 +599,6 @@ def mcc_setup(
|
|
605
599
|
target_mask: NDArray,
|
606
600
|
fast_shape: Tuple[int],
|
607
601
|
fast_ft_shape: Tuple[int],
|
608
|
-
real_dtype: type,
|
609
|
-
complex_dtype: type,
|
610
602
|
shared_memory_handler: Callable,
|
611
603
|
callback_class: Callable,
|
612
604
|
callback_class_args: Dict,
|
@@ -643,6 +635,7 @@ def mcc_setup(
|
|
643
635
|
:py:meth:`mcc_scoring`
|
644
636
|
:py:class:`tme.matching_optimization.MaskedCrossCorrelation`
|
645
637
|
"""
|
638
|
+
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
646
639
|
target = backend.multiply(target, target_mask > 0, out=target)
|
647
640
|
|
648
641
|
target_pad = backend.topleft_pad(target, fast_shape)
|
@@ -673,8 +666,8 @@ def mcc_setup(
|
|
673
666
|
arr=template_mask, shared_memory_handler=shared_memory_handler
|
674
667
|
)
|
675
668
|
|
676
|
-
template_tuple = (template_buffer, template.shape,
|
677
|
-
template_mask_tuple = (template_mask_buffer, template.shape,
|
669
|
+
template_tuple = (template_buffer, template.shape, template.dtype)
|
670
|
+
template_mask_tuple = (template_mask_buffer, template.shape, template_mask.dtype)
|
678
671
|
|
679
672
|
target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
|
680
673
|
target_ft2_tuple = (target_ft2_buffer, fast_ft_shape, complex_dtype)
|
@@ -690,8 +683,6 @@ def mcc_setup(
|
|
690
683
|
"templateshape": template.shape,
|
691
684
|
"fast_shape": fast_shape,
|
692
685
|
"fast_ft_shape": fast_ft_shape,
|
693
|
-
"real_dtype": real_dtype,
|
694
|
-
"complex_dtype": complex_dtype,
|
695
686
|
"callback_class": callback_class,
|
696
687
|
"callback_class_args": callback_class_args,
|
697
688
|
}
|
@@ -703,15 +694,13 @@ def corr_scoring(
|
|
703
694
|
template: Tuple[type, Tuple[int], type],
|
704
695
|
ft_target: Tuple[type, Tuple[int], type],
|
705
696
|
inv_denominator: Tuple[type, Tuple[int], type],
|
706
|
-
|
697
|
+
numerator: Tuple[type, Tuple[int], type],
|
707
698
|
template_filter: Tuple[type, Tuple[int], type],
|
708
699
|
targetshape: Tuple[int],
|
709
700
|
templateshape: Tuple[int],
|
710
701
|
fast_shape: Tuple[int],
|
711
702
|
fast_ft_shape: Tuple[int],
|
712
703
|
rotations: NDArray,
|
713
|
-
real_dtype: type,
|
714
|
-
complex_dtype: type,
|
715
704
|
callback_class: CallbackClass,
|
716
705
|
callback_class_args: Dict,
|
717
706
|
interpolation_order: int,
|
@@ -723,7 +712,7 @@ def corr_scoring(
|
|
723
712
|
|
724
713
|
.. math::
|
725
714
|
|
726
|
-
(CC(f,g) -
|
715
|
+
(CC(f,g) - numerator) \\cdot inv\\_denominator
|
727
716
|
|
728
717
|
Parameters
|
729
718
|
----------
|
@@ -735,8 +724,8 @@ def corr_scoring(
|
|
735
724
|
inv_denominator : Tuple[type, Tuple[int], type]
|
736
725
|
Tuple containing a pointer to the inverse denominator data, its shape, and its
|
737
726
|
datatype.
|
738
|
-
|
739
|
-
Tuple containing a pointer to the
|
727
|
+
numerator : Tuple[type, Tuple[int], type]
|
728
|
+
Tuple containing a pointer to the numerator data, its shape, and its datatype.
|
740
729
|
fast_shape : Tuple[int]
|
741
730
|
The shape for fast Fourier transform.
|
742
731
|
fast_ft_shape : Tuple[int]
|
@@ -770,6 +759,8 @@ def corr_scoring(
|
|
770
759
|
:py:meth:`cam_setup`
|
771
760
|
:py:meth:`flcSphericalMask_setup`
|
772
761
|
"""
|
762
|
+
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
763
|
+
|
773
764
|
callback = callback_class
|
774
765
|
if callback_class is not None and isinstance(callback_class, type):
|
775
766
|
callback = callback_class(**callback_class_args)
|
@@ -778,15 +769,33 @@ def corr_scoring(
|
|
778
769
|
template = backend.sharedarr_to_arr(template_buffer, template_shape, template_dtype)
|
779
770
|
ft_target = backend.sharedarr_to_arr(*ft_target)
|
780
771
|
inv_denominator = backend.sharedarr_to_arr(*inv_denominator)
|
781
|
-
|
772
|
+
numerator = backend.sharedarr_to_arr(*numerator)
|
782
773
|
template_filter = backend.sharedarr_to_arr(*template_filter)
|
783
774
|
|
775
|
+
norm_func = normalize_under_mask
|
784
776
|
norm_template, template_mask, mask_sum = False, 1, 1
|
785
777
|
if "template_mask" in kwargs:
|
786
|
-
|
787
|
-
|
788
|
-
)
|
789
|
-
|
778
|
+
norm_template = True
|
779
|
+
template_mask = backend.sharedarr_to_arr(*kwargs["template_mask"])
|
780
|
+
mask_sum = backend.sum(template_mask)
|
781
|
+
if backend.datatype_bytes(template_mask.dtype) == 2:
|
782
|
+
norm_func = _normalize_under_mask_overflow_safe
|
783
|
+
mask_sum = backend.sum(
|
784
|
+
backend.astype(template_mask, backend._overflow_safe_dtype)
|
785
|
+
)
|
786
|
+
|
787
|
+
norm_template = conditional_execute(norm_func, norm_template)
|
788
|
+
norm_numerator = (backend.sum(numerator) != 0) & (backend.size(numerator) != 1)
|
789
|
+
norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
|
790
|
+
|
791
|
+
norm_denominator = (backend.sum(inv_denominator) != 1) & (
|
792
|
+
backend.size(inv_denominator) != 1
|
793
|
+
)
|
794
|
+
norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
|
795
|
+
callback_func = conditional_execute(callback, callback_class is not None)
|
796
|
+
template_filter_func = conditional_execute(
|
797
|
+
apply_filter, backend.size(template_filter) != 1
|
798
|
+
)
|
790
799
|
|
791
800
|
arr = backend.preallocate_array(fast_shape, real_dtype)
|
792
801
|
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
@@ -801,19 +810,6 @@ def corr_scoring(
|
|
801
810
|
temp_fft=ft_temp,
|
802
811
|
)
|
803
812
|
|
804
|
-
norm_numerator = (backend.sum(numerator2) != 0) & (backend.size(numerator2) != 1)
|
805
|
-
norm_denominator = (backend.sum(inv_denominator) != 1) & (
|
806
|
-
backend.size(inv_denominator) != 1
|
807
|
-
)
|
808
|
-
|
809
|
-
norm_template = conditional_execute(normalize_under_mask, norm_template)
|
810
|
-
callback_func = conditional_execute(callback, callback_class is not None)
|
811
|
-
norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
|
812
|
-
norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
|
813
|
-
template_filter_func = conditional_execute(
|
814
|
-
apply_filter, backend.size(template_filter) != 1
|
815
|
-
)
|
816
|
-
|
817
813
|
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
818
814
|
for index in range(rotations.shape[0]):
|
819
815
|
rotation = rotations[index]
|
@@ -822,10 +818,9 @@ def corr_scoring(
|
|
822
818
|
arr=template,
|
823
819
|
rotation_matrix=rotation,
|
824
820
|
out=arr,
|
825
|
-
use_geometric_center=
|
821
|
+
use_geometric_center=True,
|
826
822
|
order=interpolation_order,
|
827
823
|
)
|
828
|
-
|
829
824
|
norm_template(arr[unpadded_slice], template_mask, mask_sum)
|
830
825
|
|
831
826
|
rfftn(arr, ft_temp)
|
@@ -833,7 +828,7 @@ def corr_scoring(
|
|
833
828
|
backend.multiply(ft_target, ft_temp, out=ft_temp)
|
834
829
|
irfftn(ft_temp, arr)
|
835
830
|
|
836
|
-
norm_func_numerator(arr,
|
831
|
+
norm_func_numerator(arr, numerator, out=arr)
|
837
832
|
norm_func_denominator(arr, inv_denominator, out=arr)
|
838
833
|
|
839
834
|
callback_func(
|
@@ -857,8 +852,6 @@ def flc_scoring(
|
|
857
852
|
fast_shape: Tuple[int],
|
858
853
|
fast_ft_shape: Tuple[int],
|
859
854
|
rotations: NDArray,
|
860
|
-
real_dtype: type,
|
861
|
-
complex_dtype: type,
|
862
855
|
callback_class: CallbackClass,
|
863
856
|
callback_class_args: Dict,
|
864
857
|
interpolation_order: int,
|
@@ -890,6 +883,8 @@ def flc_scoring(
|
|
890
883
|
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
891
884
|
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
892
885
|
"""
|
886
|
+
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
887
|
+
|
893
888
|
callback = callback_class
|
894
889
|
if callback_class is not None and isinstance(callback_class, type):
|
895
890
|
callback = callback_class(**callback_class_args)
|
@@ -933,7 +928,7 @@ def flc_scoring(
|
|
933
928
|
rotation_matrix=rotation,
|
934
929
|
out=arr,
|
935
930
|
out_mask=temp,
|
936
|
-
use_geometric_center=
|
931
|
+
use_geometric_center=True,
|
937
932
|
order=interpolation_order,
|
938
933
|
)
|
939
934
|
# Given the amount of FFTs, might aswell normalize properly
|
@@ -966,7 +961,8 @@ def flc_scoring(
|
|
966
961
|
backend.multiply(ft_target, ft_temp, out=ft_temp)
|
967
962
|
irfftn(ft_temp, arr)
|
968
963
|
|
969
|
-
tol =
|
964
|
+
tol = eps
|
965
|
+
# tol = 1e3 * eps * backend.max(backend.abs(temp))
|
970
966
|
nonzero_indices = temp > tol
|
971
967
|
backend.fill(temp2, 0)
|
972
968
|
temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
|
@@ -992,8 +988,6 @@ def flc_scoring2(
|
|
992
988
|
fast_shape: Tuple[int],
|
993
989
|
fast_ft_shape: Tuple[int],
|
994
990
|
rotations: NDArray,
|
995
|
-
real_dtype: type,
|
996
|
-
complex_dtype: type,
|
997
991
|
callback_class: CallbackClass,
|
998
992
|
callback_class_args: Dict,
|
999
993
|
interpolation_order: int,
|
@@ -1025,6 +1019,7 @@ def flc_scoring2(
|
|
1025
1019
|
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
1026
1020
|
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
1027
1021
|
"""
|
1022
|
+
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
1028
1023
|
callback = callback_class
|
1029
1024
|
if callback_class is not None and isinstance(callback_class, type):
|
1030
1025
|
callback = callback_class(**callback_class_args)
|
@@ -1083,7 +1078,7 @@ def flc_scoring2(
|
|
1083
1078
|
rotation_matrix=rotation,
|
1084
1079
|
out=arr[squeeze],
|
1085
1080
|
out_mask=temp[squeeze],
|
1086
|
-
use_geometric_center=
|
1081
|
+
use_geometric_center=True,
|
1087
1082
|
order=interpolation_order,
|
1088
1083
|
)
|
1089
1084
|
# Given the amount of FFTs, might aswell normalize properly
|
@@ -1115,8 +1110,7 @@ def flc_scoring2(
|
|
1115
1110
|
backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1116
1111
|
irfftn(ft_denom, arr)
|
1117
1112
|
|
1118
|
-
|
1119
|
-
nonzero_indices = temp > tol
|
1113
|
+
nonzero_indices = temp > eps
|
1120
1114
|
backend.fill(temp2, 0)
|
1121
1115
|
temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
|
1122
1116
|
|
@@ -1141,8 +1135,6 @@ def mcc_scoring(
|
|
1141
1135
|
fast_shape: Tuple[int],
|
1142
1136
|
fast_ft_shape: Tuple[int],
|
1143
1137
|
rotations: NDArray,
|
1144
|
-
real_dtype: type,
|
1145
|
-
complex_dtype: type,
|
1146
1138
|
callback_class: CallbackClass,
|
1147
1139
|
callback_class_args: type,
|
1148
1140
|
interpolation_order: int,
|
@@ -1179,6 +1171,7 @@ def mcc_scoring(
|
|
1179
1171
|
--------
|
1180
1172
|
:py:class:`tme.matching_optimization.MaskedCrossCorrelation`
|
1181
1173
|
"""
|
1174
|
+
real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
|
1182
1175
|
callback = callback_class
|
1183
1176
|
if callback_class is not None and isinstance(callback_class, type):
|
1184
1177
|
callback = callback_class(**callback_class_args)
|
@@ -1232,7 +1225,7 @@ def mcc_scoring(
|
|
1232
1225
|
rotation_matrix=rotation,
|
1233
1226
|
out=template_rot,
|
1234
1227
|
out_mask=temp,
|
1235
|
-
use_geometric_center=
|
1228
|
+
use_geometric_center=True,
|
1236
1229
|
order=interpolation_order,
|
1237
1230
|
)
|
1238
1231
|
|
@@ -1307,7 +1300,7 @@ def mcc_scoring(
|
|
1307
1300
|
return callback
|
1308
1301
|
|
1309
1302
|
|
1310
|
-
def
|
1303
|
+
def _setup_template_filter_apply_target_filter(
|
1311
1304
|
matching_data: MatchingData,
|
1312
1305
|
rfftn: Callable,
|
1313
1306
|
irfftn: Callable,
|
@@ -1317,28 +1310,36 @@ def _setup_template_filter(
|
|
1317
1310
|
filter_template = isinstance(matching_data.template_filter, Compose)
|
1318
1311
|
filter_target = isinstance(matching_data.target_filter, Compose)
|
1319
1312
|
|
1320
|
-
template_filter = backend.full(
|
1321
|
-
shape=(1,), fill_value=1, dtype=backend._default_dtype
|
1322
|
-
)
|
1313
|
+
template_filter = backend.full(shape=(1,), fill_value=1, dtype=backend._float_dtype)
|
1323
1314
|
|
1324
1315
|
if not filter_template and not filter_target:
|
1325
1316
|
return template_filter
|
1326
1317
|
|
1327
1318
|
target_temp = backend.astype(
|
1328
|
-
backend.topleft_pad(matching_data.target, fast_shape), backend.
|
1319
|
+
backend.topleft_pad(matching_data.target, fast_shape), backend._float_dtype
|
1329
1320
|
)
|
1330
1321
|
target_temp_ft = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
|
1331
|
-
rfftn(target_temp, target_temp_ft)
|
1332
1322
|
|
1323
|
+
filter_shape = backend.multiply(fast_ft_shape, 1 - matching_data._batch_mask)
|
1324
|
+
filter_shape[filter_shape == 0] = 1
|
1325
|
+
fast_shape = backend.multiply(fast_shape, 1 - matching_data._batch_mask)
|
1326
|
+
fast_shape = fast_shape[fast_shape != 0]
|
1327
|
+
|
1328
|
+
fast_shape = tuple(int(x) for x in fast_shape)
|
1329
|
+
filter_shape = tuple(int(x) for x in filter_shape)
|
1330
|
+
|
1331
|
+
rfftn(target_temp, target_temp_ft)
|
1333
1332
|
if isinstance(matching_data.template_filter, Compose):
|
1334
1333
|
template_filter = matching_data.template_filter(
|
1335
1334
|
shape=fast_shape,
|
1336
1335
|
return_real_fourier=True,
|
1337
1336
|
shape_is_real_fourier=False,
|
1338
1337
|
data_rfft=target_temp_ft,
|
1338
|
+
batch_dimension=matching_data._target_dims,
|
1339
1339
|
)
|
1340
1340
|
template_filter = template_filter["data"]
|
1341
1341
|
template_filter[tuple(0 for _ in range(template_filter.ndim))] = 0
|
1342
|
+
template_filter = backend.reshape(template_filter, filter_shape)
|
1342
1343
|
|
1343
1344
|
if isinstance(matching_data.target_filter, Compose):
|
1344
1345
|
target_filter = matching_data.target_filter(
|
@@ -1347,8 +1348,10 @@ def _setup_template_filter(
|
|
1347
1348
|
shape_is_real_fourier=False,
|
1348
1349
|
data_rfft=target_temp_ft,
|
1349
1350
|
weight_type=None,
|
1351
|
+
batch_dimension=matching_data._target_dims,
|
1350
1352
|
)
|
1351
1353
|
target_filter = target_filter["data"]
|
1354
|
+
target_filter = backend.reshape(target_filter, filter_shape)
|
1352
1355
|
backend.multiply(target_temp_ft, target_filter, out=target_temp_ft)
|
1353
1356
|
|
1354
1357
|
irfftn(target_temp_ft, target_temp)
|
@@ -1395,8 +1398,7 @@ def scan(
|
|
1395
1398
|
**kwargs,
|
1396
1399
|
) -> Tuple:
|
1397
1400
|
"""
|
1398
|
-
Perform template matching
|
1399
|
-
different rotations of template.
|
1401
|
+
Perform template matching.
|
1400
1402
|
|
1401
1403
|
Parameters
|
1402
1404
|
----------
|
@@ -1420,25 +1422,44 @@ def scan(
|
|
1420
1422
|
Order of spline interpolation for rotations.
|
1421
1423
|
jobs_per_callback_class : int, optional
|
1422
1424
|
How many jobs should be processed by a single callback_class instance,
|
1423
|
-
if
|
1425
|
+
if one is provided.
|
1424
1426
|
**kwargs : various
|
1425
|
-
Additional arguments.
|
1427
|
+
Additional keyword arguments.
|
1426
1428
|
|
1427
1429
|
Returns
|
1428
1430
|
-------
|
1429
1431
|
Tuple
|
1430
1432
|
The merged results from callback_class if provided otherwise None.
|
1433
|
+
|
1434
|
+
Examples
|
1435
|
+
--------
|
1436
|
+
Schematically, using :py:meth:`scan` is similar to :py:meth:`scan_subsets`,
|
1437
|
+
with the distinction that the objects contained in ``matching_data`` are not
|
1438
|
+
split and the search is only parallelized over angles.
|
1439
|
+
Assuming you have followed the example in :py:meth:`scan_subsets`, :py:meth:`scan`
|
1440
|
+
can be invoked like so
|
1441
|
+
|
1442
|
+
>>> from tme.matching_exhaustive import scan
|
1443
|
+
>>> results = scan(
|
1444
|
+
>>> matching_data = matching_data,
|
1445
|
+
>>> matching_score = matching_score,
|
1446
|
+
>>> matching_setup = matching_setup,
|
1447
|
+
>>> callback_class = callback_class,
|
1448
|
+
>>> callback_class_args = callback_class_args,
|
1449
|
+
>>> )
|
1450
|
+
|
1431
1451
|
"""
|
1432
1452
|
matching_data.to_backend()
|
1433
1453
|
shape_diff = backend.subtract(
|
1434
1454
|
matching_data._output_target_shape, matching_data._output_template_shape
|
1435
1455
|
)
|
1436
|
-
shape_diff = backend.multiply(shape_diff,
|
1456
|
+
shape_diff = backend.multiply(shape_diff, 1 - matching_data._batch_mask)
|
1457
|
+
|
1437
1458
|
if backend.sum(shape_diff < 0) and not pad_fourier:
|
1438
1459
|
warnings.warn(
|
1439
1460
|
"Target is larger than template and Fourier padding is turned off."
|
1440
|
-
" This can lead to shifted results. You can swap template and target,
|
1441
|
-
" zero-pad the target."
|
1461
|
+
" This can lead to shifted results. You can swap template and target,"
|
1462
|
+
" zero-pad the target or turn off template centering."
|
1442
1463
|
)
|
1443
1464
|
fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
|
1444
1465
|
pad_fourier=pad_fourier
|
@@ -1448,18 +1469,18 @@ def scan(
|
|
1448
1469
|
rfftn, irfftn = backend.build_fft(
|
1449
1470
|
fast_shape=fast_shape,
|
1450
1471
|
fast_ft_shape=fast_ft_shape,
|
1451
|
-
real_dtype=
|
1452
|
-
complex_dtype=
|
1472
|
+
real_dtype=backend._float_dtype,
|
1473
|
+
complex_dtype=backend._complex_dtype,
|
1453
1474
|
fftargs=fftargs,
|
1454
1475
|
)
|
1455
1476
|
|
1456
|
-
|
1457
|
-
|
1458
|
-
|
1459
|
-
|
1460
|
-
|
1461
|
-
|
1462
|
-
|
1477
|
+
template_filter = _setup_template_filter_apply_target_filter(
|
1478
|
+
matching_data=matching_data,
|
1479
|
+
rfftn=rfftn,
|
1480
|
+
irfftn=irfftn,
|
1481
|
+
fast_shape=fast_shape,
|
1482
|
+
fast_ft_shape=fast_ft_shape,
|
1483
|
+
)
|
1463
1484
|
|
1464
1485
|
setup = matching_setup(
|
1465
1486
|
rfftn=rfftn,
|
@@ -1470,31 +1491,14 @@ def scan(
|
|
1470
1491
|
target_mask=matching_data.target_mask,
|
1471
1492
|
fast_shape=fast_shape,
|
1472
1493
|
fast_ft_shape=fast_ft_shape,
|
1473
|
-
real_dtype=matching_data._default_dtype,
|
1474
|
-
complex_dtype=matching_data._complex_dtype,
|
1475
1494
|
callback_class=callback_class,
|
1476
1495
|
callback_class_args=callback_class_args,
|
1477
1496
|
**kwargs,
|
1478
1497
|
)
|
1479
1498
|
rfftn, irfftn = None, None
|
1480
1499
|
|
1481
|
-
|
1482
|
-
template_filter, preprocessor = None, Preprocessor()
|
1483
|
-
for method, parameters in matching_data.template_filter.items():
|
1484
|
-
parameters["shape"] = fast_shape
|
1485
|
-
parameters["omit_negative_frequencies"] = True
|
1486
|
-
out = preprocessor.apply_method(method=method, parameters=parameters)
|
1487
|
-
if template_filter is None:
|
1488
|
-
template_filter = out
|
1489
|
-
np.multiply(template_filter, out, out=template_filter)
|
1490
|
-
|
1491
|
-
if template_filter is None:
|
1492
|
-
template_filter = backend.full(
|
1493
|
-
shape=(1,), fill_value=1, dtype=backend._default_dtype
|
1494
|
-
)
|
1495
|
-
|
1496
1500
|
template_filter = backend.to_backend_array(template_filter)
|
1497
|
-
template_filter = backend.astype(template_filter, backend.
|
1501
|
+
template_filter = backend.astype(template_filter, backend._float_dtype)
|
1498
1502
|
template_filter_buffer = backend.arr_to_sharedarr(
|
1499
1503
|
arr=template_filter,
|
1500
1504
|
shared_memory_handler=kwargs.get("shared_memory_handler", None),
|
@@ -1520,7 +1524,6 @@ def scan(
|
|
1520
1524
|
if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
|
1521
1525
|
convolution_mode = "valid"
|
1522
1526
|
|
1523
|
-
|
1524
1527
|
callback_class_args["fourier_shift"] = fourier_shift
|
1525
1528
|
callback_class_args["convolution_mode"] = convolution_mode
|
1526
1529
|
callback_class_args["targetshape"] = setup["targetshape"]
|
@@ -1530,9 +1533,9 @@ def scan(
|
|
1530
1533
|
callback_classes = [
|
1531
1534
|
class_name(
|
1532
1535
|
score_space_shape=fast_shape,
|
1533
|
-
score_space_dtype=
|
1536
|
+
score_space_dtype=backend._float_dtype,
|
1534
1537
|
shared_memory_handler=kwargs.get("shared_memory_handler", None),
|
1535
|
-
rotation_space_dtype=backend.
|
1538
|
+
rotation_space_dtype=backend._int_dtype,
|
1536
1539
|
**callback_class_args,
|
1537
1540
|
)
|
1538
1541
|
for class_name in callback_classes
|
@@ -1571,23 +1574,31 @@ def scan(
|
|
1571
1574
|
|
1572
1575
|
callbacks = callbacks[0:n_callback_classes]
|
1573
1576
|
callbacks = [
|
1574
|
-
tuple(
|
1575
|
-
|
1576
|
-
|
1577
|
-
|
1578
|
-
|
1579
|
-
|
1580
|
-
|
1581
|
-
|
1577
|
+
tuple(
|
1578
|
+
callback._postprocess(
|
1579
|
+
fourier_shift=fourier_shift,
|
1580
|
+
convolution_mode=convolution_mode,
|
1581
|
+
targetshape=setup["targetshape"],
|
1582
|
+
templateshape=setup["templateshape"],
|
1583
|
+
shared_memory_handler=kwargs.get("shared_memory_handler", None),
|
1584
|
+
)
|
1585
|
+
)
|
1586
|
+
if hasattr(callback, "_postprocess")
|
1587
|
+
else tuple(callback)
|
1588
|
+
for callback in callbacks
|
1589
|
+
if callback is not None
|
1582
1590
|
]
|
1583
1591
|
backend.free_cache()
|
1584
1592
|
|
1585
1593
|
merged_callback = None
|
1586
1594
|
if callback_class is not None:
|
1595
|
+
score_indices = None
|
1596
|
+
if hasattr(matching_data, "indices"):
|
1597
|
+
score_indices = matching_data.indices
|
1587
1598
|
merged_callback = callback_class.merge(
|
1588
1599
|
callbacks,
|
1589
1600
|
**callback_class_args,
|
1590
|
-
score_indices=
|
1601
|
+
score_indices=score_indices,
|
1591
1602
|
inner_merge=True,
|
1592
1603
|
)
|
1593
1604
|
|
@@ -1610,14 +1621,14 @@ def scan_subsets(
|
|
1610
1621
|
**kwargs,
|
1611
1622
|
) -> Tuple:
|
1612
1623
|
"""
|
1613
|
-
Wrapper around :py:meth:`scan` that supports
|
1614
|
-
of
|
1624
|
+
Wrapper around :py:meth:`scan` that supports matching on splits
|
1625
|
+
of ``matching_data``.
|
1615
1626
|
|
1616
1627
|
Parameters
|
1617
1628
|
----------
|
1618
|
-
matching_data : MatchingData
|
1619
|
-
|
1620
|
-
|
1629
|
+
matching_data : :py:class:`tme.matching_data.MatchingData`
|
1630
|
+
MatchingData instance containing relevant data.
|
1631
|
+
matching_setup : type
|
1621
1632
|
Function pointer to setup function.
|
1622
1633
|
matching_score : type
|
1623
1634
|
Function pointer to scoring function.
|
@@ -1626,11 +1637,15 @@ def scan_subsets(
|
|
1626
1637
|
callback_class_args : dict, optional
|
1627
1638
|
Arguments passed to the callback_class. Default is an empty dictionary.
|
1628
1639
|
job_schedule : tuple of int, optional
|
1629
|
-
|
1640
|
+
Job scheduling scheme, default is (1, 1). First value corresponds
|
1641
|
+
to the number of splits that are processed in parallel, the second
|
1642
|
+
to the number of angles evaluated in parallel on each split.
|
1630
1643
|
target_splits : dict, optional
|
1631
|
-
Splits for target. Default is an empty dictionary, i.e. no splits
|
1644
|
+
Splits for target. Default is an empty dictionary, i.e. no splits.
|
1645
|
+
See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
|
1632
1646
|
template_splits : dict, optional
|
1633
1647
|
Splits for template. Default is an empty dictionary, i.e. no splits.
|
1648
|
+
See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
|
1634
1649
|
pad_target_edges : bool, optional
|
1635
1650
|
Whether to pad the target boundaries by half the template shape
|
1636
1651
|
along each axis.
|
@@ -1652,10 +1667,74 @@ def scan_subsets(
|
|
1652
1667
|
-------
|
1653
1668
|
Tuple
|
1654
1669
|
The merged results from callback_class if provided otherwise None.
|
1670
|
+
|
1671
|
+
Examples
|
1672
|
+
--------
|
1673
|
+
All data relevant to template matching will be contained in ``matching_data``, which
|
1674
|
+
is a :py:class:`tme.matching_data.MatchingData` instance and can be created like so
|
1675
|
+
|
1676
|
+
>>> import numpy as np
|
1677
|
+
>>> from tme.matching_data import MatchingData
|
1678
|
+
>>> from tme.matching_utils import get_rotation_matrices
|
1679
|
+
>>> target = np.random.rand(50,40,60)
|
1680
|
+
>>> template = target[15:25, 10:20, 30:40]
|
1681
|
+
>>> matching_data = MatchingData(target, template)
|
1682
|
+
>>> matching_data.rotations = get_rotation_matrices(
|
1683
|
+
>>> angular_sampling = 60, dim = target.ndim
|
1684
|
+
>>> )
|
1685
|
+
|
1686
|
+
The template matching procedure is determined by ``matching_setup`` and
|
1687
|
+
``matching_score``, which are unique to each score. In the following,
|
1688
|
+
we will be using the `FLCSphericalMask` score, which is composed of
|
1689
|
+
:py:meth:`flcSphericalMask_setup` and :py:meth:`corr_scoring`
|
1690
|
+
|
1691
|
+
>>> from tme.matching_exhaustive import MATCHING_EXHAUSTIVE_REGISTER
|
1692
|
+
>>> funcs = MATCHING_EXHAUSTIVE_REGISTER.get("FLCSphericalMask")
|
1693
|
+
>>> matching_setup, matching_score = funcs
|
1694
|
+
|
1695
|
+
Computed scores are flexibly analyzed by being passed through an analyzer. In the
|
1696
|
+
following, we will use :py:class:`tme.analyzer.MaxScoreOverRotations` to
|
1697
|
+
aggregate scores over rotations:
|
1698
|
+
|
1699
|
+
>>> from tme.analyzer import MaxScoreOverRotations
|
1700
|
+
>>> callback_class = MaxScoreOverRotations
|
1701
|
+
>>> callback_class_args = {"score_threshold" : 0}
|
1702
|
+
|
1703
|
+
In case the entire template matching problem does not fit into memory, we can
|
1704
|
+
determine the splitting procedure. In this case, we halve the first axis of the target
|
1705
|
+
once. Splitting and ``job_schedule`` are typically computed using
|
1706
|
+
:py:meth:`tme.matching_utils.compute_parallelization_schedule`.
|
1707
|
+
|
1708
|
+
>>> target_splits = {0 : 1}
|
1709
|
+
|
1710
|
+
Finally, we can perform template matching. Note that the data
|
1711
|
+
contained in ``matching_data`` will be destroyed when running the following
|
1712
|
+
|
1713
|
+
>>> from tme.matching_exhaustive import scan_subsets
|
1714
|
+
>>> results = scan_subsets(
|
1715
|
+
>>> matching_data = matching_data,
|
1716
|
+
>>> matching_score = matching_score,
|
1717
|
+
>>> matching_setup = matching_setup,
|
1718
|
+
>>> callback_class = callback_class,
|
1719
|
+
>>> callback_class_args = callback_class_args,
|
1720
|
+
>>> target_splits = target_splits,
|
1721
|
+
>>> )
|
1722
|
+
|
1723
|
+
The returned ``results`` tuple contains the output of the chosen analyzer.
|
1724
|
+
|
1725
|
+
See Also
|
1726
|
+
--------
|
1727
|
+
:py:meth:`tme.matching_utils.compute_parallelization_schedule`
|
1655
1728
|
"""
|
1656
1729
|
target_splits = split_numpy_array_slices(
|
1657
1730
|
matching_data._target.shape, splits=target_splits
|
1658
1731
|
)
|
1732
|
+
if (len(target_splits) > 1) and not pad_target_edges:
|
1733
|
+
warnings.warn(
|
1734
|
+
"Target splitting without padding target edges leads to unreliable "
|
1735
|
+
"similarity estimates around the split border."
|
1736
|
+
)
|
1737
|
+
|
1659
1738
|
template_splits = split_numpy_array_slices(
|
1660
1739
|
matching_data._template.shape, splits=template_splits
|
1661
1740
|
)
|