pytme 0.1.9__cp311-cp311-macosx_14_0_arm64.whl → 0.2.0__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.0.data/scripts/match_template.py +1019 -0
- pytme-0.2.0.data/scripts/postprocess.py +570 -0
- {pytme-0.1.9.data → pytme-0.2.0.data}/scripts/preprocessor_gui.py +244 -60
- {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/METADATA +3 -1
- pytme-0.2.0.dist-info/RECORD +72 -0
- {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/WHEEL +1 -1
- scripts/extract_candidates.py +218 -0
- scripts/match_template.py +459 -218
- pytme-0.1.9.data/scripts/match_template.py → scripts/match_template_filters.py +459 -218
- scripts/postprocess.py +380 -435
- scripts/preprocessor_gui.py +244 -60
- scripts/refine_matches.py +218 -0
- tme/__init__.py +2 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +533 -78
- tme/backends/cupy_backend.py +80 -15
- tme/backends/npfftw_backend.py +35 -6
- tme/backends/pytorch_backend.py +15 -7
- tme/density.py +173 -78
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_constrained.py +195 -0
- tme/matching_data.py +76 -33
- tme/matching_exhaustive.py +354 -225
- tme/matching_memory.py +1 -0
- tme/matching_optimization.py +753 -649
- tme/matching_utils.py +152 -8
- tme/orientations.py +561 -0
- tme/preprocessing/__init__.py +2 -0
- tme/preprocessing/_utils.py +176 -0
- tme/preprocessing/composable_filter.py +30 -0
- tme/preprocessing/compose.py +52 -0
- tme/preprocessing/frequency_filters.py +322 -0
- tme/preprocessing/tilt_series.py +967 -0
- tme/preprocessor.py +35 -25
- tme/structure.py +2 -37
- pytme-0.1.9.data/scripts/postprocess.py +0 -625
- pytme-0.1.9.dist-info/RECORD +0 -61
- {pytme-0.1.9.data → pytme-0.2.0.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.1.9.data → pytme-0.2.0.data}/scripts/preprocess.py +0 -0
- {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/LICENSE +0 -0
- {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/entry_points.txt +0 -0
- {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/top_level.txt +0 -0
tme/matching_exhaustive.py
CHANGED
@@ -19,13 +19,12 @@ from scipy.ndimage import laplace
|
|
19
19
|
|
20
20
|
from .analyzer import MaxScoreOverRotations
|
21
21
|
from .matching_utils import (
|
22
|
-
apply_convolution_mode,
|
23
22
|
handle_traceback,
|
24
23
|
split_numpy_array_slices,
|
25
24
|
conditional_execute,
|
26
25
|
)
|
27
26
|
from .matching_memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
|
28
|
-
from .
|
27
|
+
from .preprocessing import Compose
|
29
28
|
from .matching_data import MatchingData
|
30
29
|
from .backends import backend
|
31
30
|
from .types import NDArray, CallbackClass
|
@@ -43,6 +42,58 @@ def _run_inner(backend_name, backend_args, **kwargs):
|
|
43
42
|
return scan(**kwargs)
|
44
43
|
|
45
44
|
|
45
|
+
def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
|
46
|
+
"""
|
47
|
+
Standardizes the values in in template by subtracting the mean and dividing by the
|
48
|
+
standard deviation based on the elements in mask. Subsequently, the template is
|
49
|
+
multiplied by the mask.
|
50
|
+
|
51
|
+
Parameters
|
52
|
+
----------
|
53
|
+
template : NDArray
|
54
|
+
The data array to be normalized. This array is modified in-place.
|
55
|
+
mask : NDArray
|
56
|
+
A boolean array of the same shape as `template`. True values indicate the positions in `template`
|
57
|
+
to consider for normalization.
|
58
|
+
mask_intensity : float
|
59
|
+
Mask intensity used to compute expectations.
|
60
|
+
|
61
|
+
References
|
62
|
+
----------
|
63
|
+
.. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
64
|
+
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
65
|
+
.. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
|
66
|
+
R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
|
67
|
+
13375 (2023)
|
68
|
+
|
69
|
+
Returns
|
70
|
+
-------
|
71
|
+
None
|
72
|
+
This function modifies `template` in-place and does not return any value.
|
73
|
+
"""
|
74
|
+
masked_mean = backend.sum(backend.multiply(template, mask))
|
75
|
+
masked_mean = backend.divide(masked_mean, mask_intensity)
|
76
|
+
masked_std = backend.sum(backend.multiply(backend.square(template), mask))
|
77
|
+
masked_std = backend.subtract(
|
78
|
+
masked_std / mask_intensity, backend.square(masked_mean)
|
79
|
+
)
|
80
|
+
masked_std = backend.sqrt(backend.maximum(masked_std, 0))
|
81
|
+
|
82
|
+
backend.subtract(template, masked_mean, out=template)
|
83
|
+
backend.divide(template, masked_std, out=template)
|
84
|
+
backend.multiply(template, mask, out=template)
|
85
|
+
return None
|
86
|
+
|
87
|
+
|
88
|
+
def apply_filter(ft_template, template_filter):
|
89
|
+
# This is an approximation to applying the mask, irfftn, normalize, rfftn
|
90
|
+
std_before = backend.std(ft_template)
|
91
|
+
backend.multiply(ft_template, template_filter, out=ft_template)
|
92
|
+
backend.multiply(
|
93
|
+
ft_template, std_before / backend.std(ft_template), out=ft_template
|
94
|
+
)
|
95
|
+
|
96
|
+
|
46
97
|
def cc_setup(
|
47
98
|
rfftn: Callable,
|
48
99
|
irfftn: Callable,
|
@@ -212,13 +263,12 @@ def corr_setup(
|
|
212
263
|
template_mean = backend.sum(backend.multiply(template, template_mask))
|
213
264
|
template_mean = backend.divide(template_mean, n_observations)
|
214
265
|
template_ssd = backend.sum(
|
215
|
-
backend.square(
|
216
|
-
backend.multiply(template, template_mean),
|
217
|
-
|
218
|
-
))
|
266
|
+
backend.square(
|
267
|
+
backend.multiply(backend.multiply(template, template_mean), template_mask)
|
268
|
+
)
|
219
269
|
)
|
220
270
|
template_volume = np.prod(template.shape)
|
221
|
-
backend.multiply(template, template_mask, out
|
271
|
+
backend.multiply(template, template_mask, out=template)
|
222
272
|
|
223
273
|
# Final numerator is score - numerator2
|
224
274
|
numerator2 = backend.multiply(target_window_sum, template_mean)
|
@@ -314,49 +364,6 @@ def cam_setup(**kwargs):
|
|
314
364
|
return corr_setup(**kwargs)
|
315
365
|
|
316
366
|
|
317
|
-
def _normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
|
318
|
-
"""
|
319
|
-
Standardizes the values in in template by subtracting the mean and dividing by the
|
320
|
-
standard deviation based on the elements in mask. Subsequently, the template is
|
321
|
-
multiplied by the mask.
|
322
|
-
|
323
|
-
Parameters
|
324
|
-
----------
|
325
|
-
template : NDArray
|
326
|
-
The data array to be normalized. This array is modified in-place.
|
327
|
-
mask : NDArray
|
328
|
-
A boolean array of the same shape as `template`. True values indicate the positions in `template`
|
329
|
-
to consider for normalization.
|
330
|
-
mask_intensity : float
|
331
|
-
Mask intensity used to compute expectations.
|
332
|
-
|
333
|
-
References
|
334
|
-
----------
|
335
|
-
.. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
336
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
337
|
-
.. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
|
338
|
-
R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
|
339
|
-
13375 (2023)
|
340
|
-
|
341
|
-
Returns
|
342
|
-
-------
|
343
|
-
None
|
344
|
-
This function modifies `template` in-place and does not return any value.
|
345
|
-
"""
|
346
|
-
masked_mean = backend.sum(backend.multiply(template, mask))
|
347
|
-
masked_mean = backend.divide(masked_mean, mask_intensity)
|
348
|
-
masked_std = backend.sum(backend.multiply(backend.square(template), mask))
|
349
|
-
masked_std = backend.subtract(
|
350
|
-
masked_std / mask_intensity, backend.square(masked_mean)
|
351
|
-
)
|
352
|
-
masked_std = backend.sqrt(backend.maximum(masked_std, 0))
|
353
|
-
|
354
|
-
backend.subtract(template, masked_mean, out=template)
|
355
|
-
backend.divide(template, masked_std, out=template)
|
356
|
-
backend.multiply(template, mask, out=template)
|
357
|
-
return None
|
358
|
-
|
359
|
-
|
360
367
|
def flc_setup(
|
361
368
|
rfftn: Callable,
|
362
369
|
irfftn: Callable,
|
@@ -419,7 +426,7 @@ def flc_setup(
|
|
419
426
|
arr=ft_target2, shared_memory_handler=shared_memory_handler
|
420
427
|
)
|
421
428
|
|
422
|
-
|
429
|
+
normalize_under_mask(
|
423
430
|
template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
|
424
431
|
)
|
425
432
|
|
@@ -541,7 +548,7 @@ def flcSphericalMask_setup(
|
|
541
548
|
backend.fill(temp2, 0)
|
542
549
|
temp2[nonzero_indices] = 1 / temp[nonzero_indices]
|
543
550
|
|
544
|
-
|
551
|
+
normalize_under_mask(
|
545
552
|
template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
|
546
553
|
)
|
547
554
|
|
@@ -728,10 +735,6 @@ def corr_scoring(
|
|
728
735
|
datatype.
|
729
736
|
numerator2 : Tuple[type, Tuple[int], type]
|
730
737
|
Tuple containing a pointer to the numerator2 data, its shape, and its datatype.
|
731
|
-
targetshape : Tuple[int]
|
732
|
-
The shape of the target.
|
733
|
-
templateshape : Tuple[int]
|
734
|
-
The shape of the template.
|
735
738
|
fast_shape : Tuple[int]
|
736
739
|
The shape for fast Fourier transform.
|
737
740
|
fast_ft_shape : Tuple[int]
|
@@ -750,8 +753,6 @@ def corr_scoring(
|
|
750
753
|
instantiable.
|
751
754
|
interpolation_order : int
|
752
755
|
The order of interpolation to be used while rotating the template.
|
753
|
-
convolution_mode : str, optional
|
754
|
-
Mode to use for convolution, default is "full".
|
755
756
|
**kwargs :
|
756
757
|
Additional arguments to be passed to the function.
|
757
758
|
|
@@ -767,37 +768,23 @@ def corr_scoring(
|
|
767
768
|
:py:meth:`cam_setup`
|
768
769
|
:py:meth:`flcSphericalMask_setup`
|
769
770
|
"""
|
770
|
-
|
771
|
-
ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
|
772
|
-
inv_denominator_buffer, inv_denominator_pointer_shape, _ = inv_denominator
|
773
|
-
numerator2_buffer, numerator2_shape, _ = numerator2
|
774
|
-
filter_buffer, filter_shape, filter_dtype = template_filter
|
775
|
-
|
771
|
+
callback = callback_class
|
776
772
|
if callback_class is not None and isinstance(callback_class, type):
|
777
773
|
callback = callback_class(**callback_class_args)
|
778
|
-
elif not isinstance(callback_class, type):
|
779
|
-
callback = callback_class
|
780
774
|
|
781
|
-
|
782
|
-
template = backend.sharedarr_to_arr(template_shape, template_dtype
|
783
|
-
ft_target = backend.sharedarr_to_arr(
|
784
|
-
|
785
|
-
)
|
786
|
-
|
787
|
-
inv_denominator_pointer_shape, template_dtype, inv_denominator_buffer
|
788
|
-
)
|
789
|
-
numerator2 = backend.sharedarr_to_arr(
|
790
|
-
numerator2_shape, template_dtype, numerator2_buffer
|
791
|
-
)
|
792
|
-
template_filter = backend.sharedarr_to_arr(
|
793
|
-
filter_shape, filter_dtype, filter_buffer
|
794
|
-
)
|
775
|
+
template_buffer, template_shape, template_dtype = template
|
776
|
+
template = backend.sharedarr_to_arr(template_buffer, template_shape, template_dtype)
|
777
|
+
ft_target = backend.sharedarr_to_arr(*ft_target)
|
778
|
+
inv_denominator = backend.sharedarr_to_arr(*inv_denominator)
|
779
|
+
numerator2 = backend.sharedarr_to_arr(*numerator2)
|
780
|
+
template_filter = backend.sharedarr_to_arr(*template_filter)
|
795
781
|
|
796
782
|
norm_template, template_mask, mask_sum = False, 1, 1
|
797
783
|
if "template_mask" in kwargs:
|
798
|
-
template_mask = backend.sharedarr_to_arr(
|
784
|
+
template_mask = backend.sharedarr_to_arr(
|
785
|
+
kwargs["template_mask"][0], template_shape, template_dtype
|
786
|
+
)
|
799
787
|
norm_template, mask_sum = True, backend.sum(template_mask)
|
800
|
-
norm_template = conditional_execute(_normalize_under_mask, norm_template)
|
801
788
|
|
802
789
|
arr = backend.preallocate_array(fast_shape, real_dtype)
|
803
790
|
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
@@ -816,17 +803,15 @@ def corr_scoring(
|
|
816
803
|
norm_denominator = (backend.sum(inv_denominator) != 1) & (
|
817
804
|
backend.size(inv_denominator) != 1
|
818
805
|
)
|
819
|
-
filter_template = backend.size(template_filter) != 0
|
820
806
|
|
807
|
+
norm_template = conditional_execute(normalize_under_mask, norm_template)
|
808
|
+
callback_func = conditional_execute(callback, callback_class is not None)
|
821
809
|
norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
|
822
810
|
norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
|
823
|
-
template_filter_func = conditional_execute(
|
824
|
-
|
825
|
-
|
826
|
-
fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
|
827
|
-
fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
|
811
|
+
template_filter_func = conditional_execute(
|
812
|
+
apply_filter, backend.size(template_filter) != 1
|
813
|
+
)
|
828
814
|
|
829
|
-
template_sum = backend.sum(template)
|
830
815
|
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
831
816
|
for index in range(rotations.shape[0]):
|
832
817
|
rotation = rotations[index]
|
@@ -838,34 +823,24 @@ def corr_scoring(
|
|
838
823
|
use_geometric_center=False,
|
839
824
|
order=interpolation_order,
|
840
825
|
)
|
841
|
-
|
842
|
-
backend.multiply(arr, rotation_norm, out=arr)
|
826
|
+
|
843
827
|
norm_template(arr[unpadded_slice], template_mask, mask_sum)
|
844
828
|
|
845
829
|
rfftn(arr, ft_temp)
|
846
|
-
template_filter_func(ft_temp, template_filter
|
847
|
-
|
830
|
+
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
848
831
|
backend.multiply(ft_target, ft_temp, out=ft_temp)
|
849
832
|
irfftn(ft_temp, arr)
|
850
833
|
|
851
834
|
norm_func_numerator(arr, numerator2, out=arr)
|
852
835
|
norm_func_denominator(arr, inv_denominator, out=arr)
|
853
836
|
|
854
|
-
|
855
|
-
arr
|
856
|
-
|
857
|
-
|
858
|
-
|
837
|
+
callback_func(
|
838
|
+
arr,
|
839
|
+
rotation_matrix=rotation,
|
840
|
+
rotation_index=index,
|
841
|
+
**callback_class_args,
|
859
842
|
)
|
860
843
|
|
861
|
-
if callback_class is not None:
|
862
|
-
callback(
|
863
|
-
score,
|
864
|
-
rotation_matrix=rotation,
|
865
|
-
rotation_index=index,
|
866
|
-
**callback_class_args,
|
867
|
-
)
|
868
|
-
|
869
844
|
return callback
|
870
845
|
|
871
846
|
|
@@ -913,32 +888,15 @@ def flc_scoring(
|
|
913
888
|
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
914
889
|
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
915
890
|
"""
|
916
|
-
|
917
|
-
template_mask_buffer, *_ = template_mask
|
918
|
-
filter_buffer, filter_shape, filter_dtype = template_filter
|
919
|
-
|
920
|
-
ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
|
921
|
-
ft_target2_buffer, *_ = ft_target2
|
922
|
-
|
891
|
+
callback = callback_class
|
923
892
|
if callback_class is not None and isinstance(callback_class, type):
|
924
893
|
callback = callback_class(**callback_class_args)
|
925
|
-
elif not isinstance(callback_class, type):
|
926
|
-
callback = callback_class
|
927
894
|
|
928
|
-
|
929
|
-
|
930
|
-
|
931
|
-
|
932
|
-
)
|
933
|
-
ft_target = backend.sharedarr_to_arr(
|
934
|
-
ft_target_shape, ft_target_dtype, ft_target_buffer
|
935
|
-
)
|
936
|
-
ft_target2 = backend.sharedarr_to_arr(
|
937
|
-
ft_target_shape, ft_target_dtype, ft_target2_buffer
|
938
|
-
)
|
939
|
-
template_filter = backend.sharedarr_to_arr(
|
940
|
-
filter_shape, filter_dtype, filter_buffer
|
941
|
-
)
|
895
|
+
template = backend.sharedarr_to_arr(*template)
|
896
|
+
template_mask = backend.sharedarr_to_arr(*template_mask)
|
897
|
+
ft_target = backend.sharedarr_to_arr(*ft_target)
|
898
|
+
ft_target2 = backend.sharedarr_to_arr(*ft_target2)
|
899
|
+
template_filter = backend.sharedarr_to_arr(*template_filter)
|
942
900
|
|
943
901
|
arr = backend.preallocate_array(fast_shape, real_dtype)
|
944
902
|
temp = backend.preallocate_array(fast_shape, real_dtype)
|
@@ -957,12 +915,10 @@ def flc_scoring(
|
|
957
915
|
temp_fft=ft_temp,
|
958
916
|
)
|
959
917
|
eps = backend.eps(real_dtype)
|
960
|
-
|
961
|
-
|
962
|
-
|
963
|
-
|
964
|
-
fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
|
965
|
-
fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
|
918
|
+
template_filter_func = conditional_execute(
|
919
|
+
apply_filter, backend.size(template_filter) != 1
|
920
|
+
)
|
921
|
+
callback_func = conditional_execute(callback, callback_class is not None)
|
966
922
|
|
967
923
|
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
968
924
|
for index in range(rotations.shape[0]):
|
@@ -981,7 +937,7 @@ def flc_scoring(
|
|
981
937
|
# Given the amount of FFTs, might aswell normalize properly
|
982
938
|
n_observations = backend.sum(temp)
|
983
939
|
|
984
|
-
|
940
|
+
normalize_under_mask(
|
985
941
|
template=arr[unpadded_slice],
|
986
942
|
mask=temp[unpadded_slice],
|
987
943
|
mask_intensity=n_observations,
|
@@ -1004,7 +960,7 @@ def flc_scoring(
|
|
1004
960
|
backend.multiply(temp, n_observations, out=temp)
|
1005
961
|
|
1006
962
|
rfftn(arr, ft_temp)
|
1007
|
-
template_filter_func(ft_temp, template_filter
|
963
|
+
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
1008
964
|
backend.multiply(ft_target, ft_temp, out=ft_temp)
|
1009
965
|
irfftn(ft_temp, arr)
|
1010
966
|
|
@@ -1013,23 +969,161 @@ def flc_scoring(
|
|
1013
969
|
backend.fill(temp2, 0)
|
1014
970
|
temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
|
1015
971
|
|
1016
|
-
|
972
|
+
callback_func(
|
973
|
+
temp2,
|
974
|
+
rotation_matrix=rotation,
|
975
|
+
rotation_index=index,
|
976
|
+
**callback_class_args,
|
977
|
+
)
|
978
|
+
|
979
|
+
return callback
|
980
|
+
|
981
|
+
|
982
|
+
def flc_scoring2(
|
983
|
+
template: Tuple[type, Tuple[int], type],
|
984
|
+
template_mask: Tuple[type, Tuple[int], type],
|
985
|
+
ft_target: Tuple[type, Tuple[int], type],
|
986
|
+
ft_target2: Tuple[type, Tuple[int], type],
|
987
|
+
template_filter: Tuple[type, Tuple[int], type],
|
988
|
+
targetshape: Tuple[int],
|
989
|
+
templateshape: Tuple[int],
|
990
|
+
fast_shape: Tuple[int],
|
991
|
+
fast_ft_shape: Tuple[int],
|
992
|
+
rotations: NDArray,
|
993
|
+
real_dtype: type,
|
994
|
+
complex_dtype: type,
|
995
|
+
callback_class: CallbackClass,
|
996
|
+
callback_class_args: Dict,
|
997
|
+
interpolation_order: int,
|
998
|
+
**kwargs,
|
999
|
+
) -> CallbackClass:
|
1000
|
+
"""
|
1001
|
+
Computes a normalized cross-correlation score of a target f a template g
|
1002
|
+
and a mask m:
|
1003
|
+
|
1004
|
+
.. math::
|
1005
|
+
|
1006
|
+
\\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
|
1007
|
+
{N_m * \\sqrt{
|
1008
|
+
\\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
|
1009
|
+
}
|
1010
|
+
|
1011
|
+
Where:
|
1012
|
+
|
1013
|
+
.. math::
|
1014
|
+
|
1015
|
+
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
1016
|
+
|
1017
|
+
and Nm is the number of voxels within the template mask m.
|
1018
|
+
|
1019
|
+
References
|
1020
|
+
----------
|
1021
|
+
.. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
|
1022
|
+
Microsc. Microanal. 26, 2516 (2020)
|
1023
|
+
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
1024
|
+
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
1025
|
+
"""
|
1026
|
+
callback = callback_class
|
1027
|
+
if callback_class is not None and isinstance(callback_class, type):
|
1028
|
+
callback = callback_class(**callback_class_args)
|
1029
|
+
|
1030
|
+
# Retrieve objects from shared memory
|
1031
|
+
template = backend.sharedarr_to_arr(*template)
|
1032
|
+
template_mask = backend.sharedarr_to_arr(*template_mask)
|
1033
|
+
ft_target = backend.sharedarr_to_arr(*ft_target)
|
1034
|
+
ft_target2 = backend.sharedarr_to_arr(*ft_target2)
|
1035
|
+
template_filter = backend.sharedarr_to_arr(*template_filter)
|
1017
1036
|
|
1018
|
-
|
1019
|
-
|
1037
|
+
arr = backend.preallocate_array(fast_shape, real_dtype)
|
1038
|
+
temp = backend.preallocate_array(fast_shape, real_dtype)
|
1039
|
+
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
1040
|
+
|
1041
|
+
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
1042
|
+
ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
1020
1043
|
|
1021
|
-
|
1022
|
-
|
1044
|
+
eps = backend.eps(real_dtype)
|
1045
|
+
template_filter_func = conditional_execute(
|
1046
|
+
apply_filter, backend.size(template_filter) != 1
|
1047
|
+
)
|
1048
|
+
callback_func = conditional_execute(callback, callback_class is not None)
|
1049
|
+
|
1050
|
+
squeeze_axis = tuple(i for i, x in enumerate(template.shape) if x == 1)
|
1051
|
+
squeeze = tuple(
|
1052
|
+
slice(0, stop) if i not in squeeze_axis else 0
|
1053
|
+
for i, stop in enumerate(template.shape)
|
1054
|
+
)
|
1055
|
+
squeeze_fast = tuple(
|
1056
|
+
slice(0, stop) if i not in squeeze_axis else 0
|
1057
|
+
for i, stop in enumerate(fast_shape)
|
1058
|
+
)
|
1059
|
+
squeeze_fast_ft = tuple(
|
1060
|
+
slice(0, stop) if i not in squeeze_axis else 0
|
1061
|
+
for i, stop in enumerate(fast_ft_shape)
|
1062
|
+
)
|
1063
|
+
|
1064
|
+
rfftn, irfftn = backend.build_fft(
|
1065
|
+
fast_shape=temp[squeeze_fast].shape,
|
1066
|
+
fast_ft_shape=fast_ft_shape,
|
1067
|
+
real_dtype=real_dtype,
|
1068
|
+
complex_dtype=complex_dtype,
|
1069
|
+
fftargs=kwargs.get("fftargs", {}),
|
1070
|
+
inverse_fast_shape=fast_shape,
|
1071
|
+
temp_real=arr[squeeze_fast],
|
1072
|
+
temp_fft=ft_temp,
|
1073
|
+
)
|
1074
|
+
for index in range(rotations.shape[0]):
|
1075
|
+
rotation = rotations[index]
|
1076
|
+
backend.fill(arr, 0)
|
1077
|
+
backend.fill(temp, 0)
|
1078
|
+
backend.rotate_array(
|
1079
|
+
arr=template[squeeze],
|
1080
|
+
arr_mask=template_mask[squeeze],
|
1081
|
+
rotation_matrix=rotation,
|
1082
|
+
out=arr[squeeze],
|
1083
|
+
out_mask=temp[squeeze],
|
1084
|
+
use_geometric_center=False,
|
1085
|
+
order=interpolation_order,
|
1023
1086
|
)
|
1087
|
+
# Given the amount of FFTs, might aswell normalize properly
|
1088
|
+
n_observations = backend.sum(temp)
|
1024
1089
|
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1028
|
-
|
1029
|
-
|
1030
|
-
|
1031
|
-
|
1090
|
+
normalize_under_mask(
|
1091
|
+
template=arr[squeeze],
|
1092
|
+
mask=temp[squeeze],
|
1093
|
+
mask_intensity=n_observations,
|
1094
|
+
)
|
1095
|
+
rfftn(temp[squeeze_fast], ft_temp[squeeze_fast_ft])
|
1096
|
+
|
1097
|
+
backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1098
|
+
irfftn(ft_denom, temp)
|
1099
|
+
backend.divide(temp, n_observations, out=temp)
|
1100
|
+
backend.square(temp, out=temp)
|
1101
|
+
|
1102
|
+
backend.multiply(ft_target2, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1103
|
+
irfftn(ft_denom, temp2)
|
1104
|
+
backend.divide(temp2, n_observations, out=temp2)
|
1032
1105
|
|
1106
|
+
backend.subtract(temp2, temp, out=temp)
|
1107
|
+
backend.maximum(temp, 0.0, out=temp)
|
1108
|
+
backend.sqrt(temp, out=temp)
|
1109
|
+
backend.multiply(temp, n_observations, out=temp)
|
1110
|
+
|
1111
|
+
rfftn(arr[squeeze_fast], ft_temp[squeeze_fast_ft])
|
1112
|
+
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
1113
|
+
backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1114
|
+
irfftn(ft_denom, arr)
|
1115
|
+
|
1116
|
+
tol = tol = 1e3 * eps * backend.max(backend.abs(temp))
|
1117
|
+
nonzero_indices = temp > tol
|
1118
|
+
backend.fill(temp2, 0)
|
1119
|
+
temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
|
1120
|
+
|
1121
|
+
callback_func(
|
1122
|
+
temp2,
|
1123
|
+
rotation_matrix=rotation,
|
1124
|
+
rotation_index=index,
|
1125
|
+
**callback_class_args,
|
1126
|
+
)
|
1033
1127
|
return callback
|
1034
1128
|
|
1035
1129
|
|
@@ -1083,35 +1177,18 @@ def mcc_scoring(
|
|
1083
1177
|
--------
|
1084
1178
|
:py:class:`tme.matching_optimization.MaskedCrossCorrelation`
|
1085
1179
|
"""
|
1086
|
-
|
1087
|
-
ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
|
1088
|
-
ft_target2_buffer, ft_target_shape, ft_target_dtype = ft_target2
|
1089
|
-
template_mask_buffer, _, _ = template
|
1090
|
-
ft_target_mask_buffer, _, _ = ft_target
|
1091
|
-
filter_buffer, filter_shape, filter_dtype = template_filter
|
1092
|
-
|
1180
|
+
callback = callback_class
|
1093
1181
|
if callback_class is not None and isinstance(callback_class, type):
|
1094
1182
|
callback = callback_class(**callback_class_args)
|
1095
|
-
elif not isinstance(callback_class, type):
|
1096
|
-
callback = callback_class
|
1097
1183
|
|
1098
1184
|
# Retrieve objects from shared memory
|
1099
|
-
|
1100
|
-
|
1101
|
-
|
1102
|
-
)
|
1103
|
-
|
1104
|
-
|
1105
|
-
)
|
1106
|
-
template_mask = backend.sharedarr_to_arr(
|
1107
|
-
template_shape, template_dtype, template_mask_buffer
|
1108
|
-
)
|
1109
|
-
target_mask_ft = backend.sharedarr_to_arr(
|
1110
|
-
ft_target_shape, ft_target_dtype, ft_target_mask_buffer
|
1111
|
-
)
|
1112
|
-
template_filter = backend.sharedarr_to_arr(
|
1113
|
-
filter_shape, filter_dtype, filter_buffer
|
1114
|
-
)
|
1185
|
+
template_buffer, template_shape, template_dtype = template
|
1186
|
+
template = backend.sharedarr_to_arr(*template)
|
1187
|
+
target_ft = backend.sharedarr_to_arr(*ft_target)
|
1188
|
+
target_ft2 = backend.sharedarr_to_arr(*ft_target2)
|
1189
|
+
template_mask = backend.sharedarr_to_arr(*template_mask)
|
1190
|
+
target_mask_ft = backend.sharedarr_to_arr(*ft_target_mask)
|
1191
|
+
template_filter = backend.sharedarr_to_arr(*template_filter)
|
1115
1192
|
|
1116
1193
|
axes = tuple(range(template.ndim))
|
1117
1194
|
eps = backend.eps(real_dtype)
|
@@ -1136,14 +1213,10 @@ def mcc_scoring(
|
|
1136
1213
|
temp_fft=temp_ft,
|
1137
1214
|
)
|
1138
1215
|
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
axis = tuple(range(template.ndim))
|
1143
|
-
fourier_shift = callback_class_args.get(
|
1144
|
-
"fourier_shift", backend.zeros(template.ndim)
|
1216
|
+
template_filter_func = conditional_execute(
|
1217
|
+
apply_filter, backend.size(template_filter) != 1
|
1145
1218
|
)
|
1146
|
-
|
1219
|
+
callback_func = conditional_execute(callback, callback_class is not None)
|
1147
1220
|
|
1148
1221
|
# Calculate scores across all rotations
|
1149
1222
|
for index in range(rotations.shape[0]):
|
@@ -1165,7 +1238,7 @@ def mcc_scoring(
|
|
1165
1238
|
|
1166
1239
|
# template_rot_ft
|
1167
1240
|
rfftn(template_rot, temp_ft)
|
1168
|
-
template_filter_func(temp_ft, template_filter
|
1241
|
+
template_filter_func(ft_template=temp_ft, template_filter=template_filter)
|
1169
1242
|
irfftn(target_mask_ft * temp_ft, temp2)
|
1170
1243
|
irfftn(target_ft * temp_ft, numerator)
|
1171
1244
|
|
@@ -1221,25 +1294,68 @@ def mcc_scoring(
|
|
1221
1294
|
mask_overlap, axis=axes, keepdims=True
|
1222
1295
|
)
|
1223
1296
|
temp[mask_overlap < number_px_threshold] = 0.0
|
1224
|
-
convolution_mode = kwargs.get("convolution_mode", "full")
|
1225
1297
|
|
1226
|
-
|
1227
|
-
temp
|
1228
|
-
|
1229
|
-
|
1230
|
-
|
1298
|
+
callback_func(
|
1299
|
+
temp,
|
1300
|
+
rotation_matrix=rotation,
|
1301
|
+
rotation_index=index,
|
1302
|
+
**callback_class_args,
|
1231
1303
|
)
|
1232
|
-
if callback_class is not None:
|
1233
|
-
callback(
|
1234
|
-
score,
|
1235
|
-
rotation_matrix=rotation,
|
1236
|
-
rotation_index=index,
|
1237
|
-
**callback_class_args,
|
1238
|
-
)
|
1239
1304
|
|
1240
1305
|
return callback
|
1241
1306
|
|
1242
1307
|
|
1308
|
+
def _setup_template_filter_apply_target_filter(
|
1309
|
+
matching_data: MatchingData,
|
1310
|
+
rfftn: Callable,
|
1311
|
+
irfftn: Callable,
|
1312
|
+
fast_shape: Tuple[int],
|
1313
|
+
fast_ft_shape: Tuple[int],
|
1314
|
+
):
|
1315
|
+
filter_template = isinstance(matching_data.template_filter, Compose)
|
1316
|
+
filter_target = isinstance(matching_data.target_filter, Compose)
|
1317
|
+
|
1318
|
+
template_filter = backend.full(
|
1319
|
+
shape=(1,), fill_value=1, dtype=backend._default_dtype
|
1320
|
+
)
|
1321
|
+
|
1322
|
+
if not filter_template and not filter_target:
|
1323
|
+
return template_filter
|
1324
|
+
|
1325
|
+
target_temp = backend.astype(
|
1326
|
+
backend.topleft_pad(matching_data.target, fast_shape), backend._default_dtype
|
1327
|
+
)
|
1328
|
+
target_temp_ft = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
|
1329
|
+
rfftn(target_temp, target_temp_ft)
|
1330
|
+
if isinstance(matching_data.template_filter, Compose):
|
1331
|
+
template_filter = matching_data.template_filter(
|
1332
|
+
shape=fast_shape,
|
1333
|
+
return_real_fourier=True,
|
1334
|
+
shape_is_real_fourier=False,
|
1335
|
+
data_rfft=target_temp_ft,
|
1336
|
+
)
|
1337
|
+
template_filter = template_filter["data"]
|
1338
|
+
template_filter[tuple(0 for _ in range(template_filter.ndim))] = 0
|
1339
|
+
|
1340
|
+
if isinstance(matching_data.target_filter, Compose):
|
1341
|
+
target_filter = matching_data.target_filter(
|
1342
|
+
shape=fast_shape,
|
1343
|
+
return_real_fourier=True,
|
1344
|
+
shape_is_real_fourier=False,
|
1345
|
+
data_rfft=target_temp_ft,
|
1346
|
+
weight_type=None,
|
1347
|
+
)
|
1348
|
+
target_filter = target_filter["data"]
|
1349
|
+
backend.multiply(target_temp_ft, target_filter, out=target_temp_ft)
|
1350
|
+
|
1351
|
+
irfftn(target_temp_ft, target_temp)
|
1352
|
+
matching_data._target = backend.topleft_pad(
|
1353
|
+
target_temp, matching_data.target.shape
|
1354
|
+
)
|
1355
|
+
|
1356
|
+
return template_filter
|
1357
|
+
|
1358
|
+
|
1243
1359
|
def device_memory_handler(func: Callable):
|
1244
1360
|
"""Decorator function providing SharedMemory Handler."""
|
1245
1361
|
|
@@ -1310,18 +1426,17 @@ def scan(
|
|
1310
1426
|
Tuple
|
1311
1427
|
The merged results from callback_class if provided otherwise None.
|
1312
1428
|
"""
|
1429
|
+
matching_data.to_backend()
|
1313
1430
|
shape_diff = backend.subtract(
|
1314
|
-
matching_data.
|
1431
|
+
matching_data._output_target_shape, matching_data._output_template_shape
|
1315
1432
|
)
|
1433
|
+
shape_diff = backend.multiply(shape_diff, ~matching_data._batch_mask)
|
1316
1434
|
if backend.sum(shape_diff < 0) and not pad_fourier:
|
1317
1435
|
warnings.warn(
|
1318
1436
|
"Target is larger than template and Fourier padding is turned off."
|
1319
|
-
" This can lead to shifted results. You can swap template and target,
|
1320
|
-
" zero-pad the target."
|
1437
|
+
" This can lead to shifted results. You can swap template and target, "
|
1438
|
+
" zero-pad the target or turn off template centering."
|
1321
1439
|
)
|
1322
|
-
|
1323
|
-
matching_data.to_backend()
|
1324
|
-
|
1325
1440
|
fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
|
1326
1441
|
pad_fourier=pad_fourier
|
1327
1442
|
)
|
@@ -1334,6 +1449,15 @@ def scan(
|
|
1334
1449
|
complex_dtype=matching_data._complex_dtype,
|
1335
1450
|
fftargs=fftargs,
|
1336
1451
|
)
|
1452
|
+
|
1453
|
+
template_filter = _setup_template_filter_apply_target_filter(
|
1454
|
+
matching_data=matching_data,
|
1455
|
+
rfftn=rfftn,
|
1456
|
+
irfftn=irfftn,
|
1457
|
+
fast_shape=fast_shape,
|
1458
|
+
fast_ft_shape=fast_ft_shape,
|
1459
|
+
)
|
1460
|
+
|
1337
1461
|
setup = matching_setup(
|
1338
1462
|
rfftn=rfftn,
|
1339
1463
|
irfftn=irfftn,
|
@@ -1351,22 +1475,7 @@ def scan(
|
|
1351
1475
|
)
|
1352
1476
|
rfftn, irfftn = None, None
|
1353
1477
|
|
1354
|
-
template_filter
|
1355
|
-
for method, parameters in matching_data.template_filter.items():
|
1356
|
-
parameters["shape"] = fast_shape
|
1357
|
-
parameters["omit_negative_frequencies"] = True
|
1358
|
-
out = preprocessor.apply_method(method=method, parameters=parameters)
|
1359
|
-
if template_filter is None:
|
1360
|
-
template_filter = out
|
1361
|
-
np.multiply(template_filter, out, out=template_filter)
|
1362
|
-
|
1363
|
-
if template_filter is None:
|
1364
|
-
template_filter = backend.full(
|
1365
|
-
shape=(1,), fill_value=1, dtype=backend._default_dtype
|
1366
|
-
)
|
1367
|
-
else:
|
1368
|
-
template_filter = backend.to_backend_array(template_filter)
|
1369
|
-
|
1478
|
+
template_filter = backend.to_backend_array(template_filter)
|
1370
1479
|
template_filter = backend.astype(template_filter, backend._default_dtype)
|
1371
1480
|
template_filter_buffer = backend.arr_to_sharedarr(
|
1372
1481
|
arr=template_filter,
|
@@ -1388,14 +1497,20 @@ def scan(
|
|
1388
1497
|
callback_class = setup.pop("callback_class", callback_class)
|
1389
1498
|
callback_class_args = setup.pop("callback_class_args", callback_class_args)
|
1390
1499
|
callback_classes = [callback_class for _ in range(n_callback_classes)]
|
1500
|
+
|
1501
|
+
convolution_mode = "same"
|
1502
|
+
if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
|
1503
|
+
convolution_mode = "valid"
|
1504
|
+
|
1505
|
+
callback_class_args["fourier_shift"] = fourier_shift
|
1506
|
+
callback_class_args["convolution_mode"] = convolution_mode
|
1507
|
+
callback_class_args["targetshape"] = setup["targetshape"]
|
1508
|
+
callback_class_args["templateshape"] = setup["templateshape"]
|
1509
|
+
|
1391
1510
|
if callback_class == MaxScoreOverRotations:
|
1392
|
-
score_space_shape = backend.subtract(
|
1393
|
-
matching_data.target.shape,
|
1394
|
-
matching_data._target_pad,
|
1395
|
-
)
|
1396
1511
|
callback_classes = [
|
1397
1512
|
class_name(
|
1398
|
-
score_space_shape=
|
1513
|
+
score_space_shape=fast_shape,
|
1399
1514
|
score_space_dtype=matching_data._default_dtype,
|
1400
1515
|
shared_memory_handler=kwargs.get("shared_memory_handler", None),
|
1401
1516
|
rotation_space_dtype=backend._default_dtype_int,
|
@@ -1435,9 +1550,20 @@ def scan(
|
|
1435
1550
|
for index, rotation in enumerate(rotation_list)
|
1436
1551
|
)
|
1437
1552
|
|
1553
|
+
callbacks = callbacks[0:n_callback_classes]
|
1438
1554
|
callbacks = [
|
1439
|
-
tuple(
|
1440
|
-
|
1555
|
+
tuple(
|
1556
|
+
callback._postprocess(
|
1557
|
+
fourier_shift=fourier_shift,
|
1558
|
+
convolution_mode=convolution_mode,
|
1559
|
+
targetshape=setup["targetshape"],
|
1560
|
+
templateshape=setup["templateshape"],
|
1561
|
+
shared_memory_handler=kwargs.get("shared_memory_handler", None),
|
1562
|
+
)
|
1563
|
+
)
|
1564
|
+
if hasattr(callback, "_postprocess")
|
1565
|
+
else tuple(callback)
|
1566
|
+
for callback in callbacks
|
1441
1567
|
if callback is not None
|
1442
1568
|
]
|
1443
1569
|
backend.free_cache()
|
@@ -1549,11 +1675,13 @@ def scan_subsets(
|
|
1549
1675
|
matching_data._target, matching_data._template = None, None
|
1550
1676
|
matching_data._target_mask, matching_data._template_mask = None, None
|
1551
1677
|
|
1678
|
+
candidates = None
|
1552
1679
|
if callback_class is not None:
|
1553
1680
|
candidates = callback_class.merge(
|
1554
1681
|
results, **callback_class_args, inner_merge=False
|
1555
1682
|
)
|
1556
|
-
|
1683
|
+
|
1684
|
+
return candidates
|
1557
1685
|
|
1558
1686
|
|
1559
1687
|
MATCHING_EXHAUSTIVE_REGISTER = {
|
@@ -1563,6 +1691,7 @@ MATCHING_EXHAUSTIVE_REGISTER = {
|
|
1563
1691
|
"CAM": (cam_setup, corr_scoring),
|
1564
1692
|
"FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
|
1565
1693
|
"FLC": (flc_setup, flc_scoring),
|
1694
|
+
"FLC2": (flc_setup, flc_scoring2),
|
1566
1695
|
"MCC": (mcc_setup, mcc_scoring),
|
1567
1696
|
}
|
1568
1697
|
|