pytme 0.1.8__cp311-cp311-macosx_14_0_arm64.whl → 0.2.0__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.0.data/scripts/match_template.py +1019 -0
- pytme-0.2.0.data/scripts/postprocess.py +570 -0
- {pytme-0.1.8.data → pytme-0.2.0.data}/scripts/preprocessor_gui.py +244 -60
- {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/METADATA +3 -1
- pytme-0.2.0.dist-info/RECORD +72 -0
- {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/WHEEL +1 -1
- scripts/extract_candidates.py +218 -0
- scripts/match_template.py +459 -218
- pytme-0.1.8.data/scripts/match_template.py → scripts/match_template_filters.py +459 -218
- scripts/postprocess.py +380 -435
- scripts/preprocessor_gui.py +244 -60
- scripts/refine_matches.py +218 -0
- tme/__init__.py +2 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +533 -78
- tme/backends/cupy_backend.py +80 -15
- tme/backends/npfftw_backend.py +35 -6
- tme/backends/pytorch_backend.py +15 -7
- tme/density.py +173 -78
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_constrained.py +195 -0
- tme/matching_data.py +78 -32
- tme/matching_exhaustive.py +369 -221
- tme/matching_memory.py +1 -0
- tme/matching_optimization.py +753 -649
- tme/matching_utils.py +152 -8
- tme/orientations.py +561 -0
- tme/preprocessing/__init__.py +2 -0
- tme/preprocessing/_utils.py +176 -0
- tme/preprocessing/composable_filter.py +30 -0
- tme/preprocessing/compose.py +52 -0
- tme/preprocessing/frequency_filters.py +322 -0
- tme/preprocessing/tilt_series.py +967 -0
- tme/preprocessor.py +35 -25
- tme/structure.py +2 -37
- pytme-0.1.8.data/scripts/postprocess.py +0 -625
- pytme-0.1.8.dist-info/RECORD +0 -61
- {pytme-0.1.8.data → pytme-0.2.0.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.1.8.data → pytme-0.2.0.data}/scripts/preprocess.py +0 -0
- {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/LICENSE +0 -0
- {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/entry_points.txt +0 -0
- {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/top_level.txt +0 -0
tme/matching_exhaustive.py
CHANGED
@@ -19,13 +19,12 @@ from scipy.ndimage import laplace
|
|
19
19
|
|
20
20
|
from .analyzer import MaxScoreOverRotations
|
21
21
|
from .matching_utils import (
|
22
|
-
apply_convolution_mode,
|
23
22
|
handle_traceback,
|
24
23
|
split_numpy_array_slices,
|
25
24
|
conditional_execute,
|
26
25
|
)
|
27
26
|
from .matching_memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
|
28
|
-
from .
|
27
|
+
from .preprocessing import Compose
|
29
28
|
from .matching_data import MatchingData
|
30
29
|
from .backends import backend
|
31
30
|
from .types import NDArray, CallbackClass
|
@@ -43,6 +42,58 @@ def _run_inner(backend_name, backend_args, **kwargs):
|
|
43
42
|
return scan(**kwargs)
|
44
43
|
|
45
44
|
|
45
|
+
def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
|
46
|
+
"""
|
47
|
+
Standardizes the values in template by subtracting the mean and dividing by the
|
48
|
+
standard deviation based on the elements in mask. Subsequently, the template is
|
49
|
+
multiplied by the mask.
|
50
|
+
|
51
|
+
Parameters
|
52
|
+
----------
|
53
|
+
template : NDArray
|
54
|
+
The data array to be normalized. This array is modified in-place.
|
55
|
+
mask : NDArray
|
56
|
+
A boolean array of the same shape as `template`. True values indicate the positions in `template`
|
57
|
+
to consider for normalization.
|
58
|
+
mask_intensity : float
|
59
|
+
Mask intensity used to compute expectations.
|
60
|
+
|
61
|
+
References
|
62
|
+
----------
|
63
|
+
.. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
64
|
+
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
65
|
+
.. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
|
66
|
+
R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
|
67
|
+
13375 (2023)
|
68
|
+
|
69
|
+
Returns
|
70
|
+
-------
|
71
|
+
None
|
72
|
+
This function modifies `template` in-place and does not return any value.
|
73
|
+
"""
|
74
|
+
masked_mean = backend.sum(backend.multiply(template, mask))
|
75
|
+
masked_mean = backend.divide(masked_mean, mask_intensity)
|
76
|
+
masked_std = backend.sum(backend.multiply(backend.square(template), mask))
|
77
|
+
masked_std = backend.subtract(
|
78
|
+
masked_std / mask_intensity, backend.square(masked_mean)
|
79
|
+
)
|
80
|
+
masked_std = backend.sqrt(backend.maximum(masked_std, 0))
|
81
|
+
|
82
|
+
backend.subtract(template, masked_mean, out=template)
|
83
|
+
backend.divide(template, masked_std, out=template)
|
84
|
+
backend.multiply(template, mask, out=template)
|
85
|
+
return None
|
86
|
+
|
87
|
+
|
88
|
+
def apply_filter(ft_template, template_filter):
|
89
|
+
# This is an approximation to applying the mask, irfftn, normalize, rfftn
|
90
|
+
std_before = backend.std(ft_template)
|
91
|
+
backend.multiply(ft_template, template_filter, out=ft_template)
|
92
|
+
backend.multiply(
|
93
|
+
ft_template, std_before / backend.std(ft_template), out=ft_template
|
94
|
+
)
|
95
|
+
|
96
|
+
|
46
97
|
def cc_setup(
|
47
98
|
rfftn: Callable,
|
48
99
|
irfftn: Callable,
|
@@ -208,11 +259,16 @@ def corr_setup(
|
|
208
259
|
target_pad, ft_target2, ft_window_template = None, None, None
|
209
260
|
|
210
261
|
# Normalizing constants
|
211
|
-
|
212
|
-
|
262
|
+
n_observations = backend.sum(template_mask)
|
263
|
+
template_mean = backend.sum(backend.multiply(template, template_mask))
|
264
|
+
template_mean = backend.divide(template_mean, n_observations)
|
213
265
|
template_ssd = backend.sum(
|
214
|
-
backend.square(
|
266
|
+
backend.square(
|
267
|
+
backend.multiply(backend.multiply(template, template_mean), template_mask)
|
268
|
+
)
|
215
269
|
)
|
270
|
+
template_volume = np.prod(template.shape)
|
271
|
+
backend.multiply(template, template_mask, out=template)
|
216
272
|
|
217
273
|
# Final numerator is score - numerator2
|
218
274
|
numerator2 = backend.multiply(target_window_sum, template_mean)
|
@@ -308,49 +364,6 @@ def cam_setup(**kwargs):
|
|
308
364
|
return corr_setup(**kwargs)
|
309
365
|
|
310
366
|
|
311
|
-
def _normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
|
312
|
-
"""
|
313
|
-
Standardizes the values in in template by subtracting the mean and dividing by the
|
314
|
-
standard deviation based on the elements in mask. Subsequently, the template is
|
315
|
-
multiplied by the mask.
|
316
|
-
|
317
|
-
Parameters
|
318
|
-
----------
|
319
|
-
template : NDArray
|
320
|
-
The data array to be normalized. This array is modified in-place.
|
321
|
-
mask : NDArray
|
322
|
-
A boolean array of the same shape as `template`. True values indicate the positions in `template`
|
323
|
-
to consider for normalization.
|
324
|
-
mask_intensity : float
|
325
|
-
Mask intensity used to compute expectations.
|
326
|
-
|
327
|
-
References
|
328
|
-
----------
|
329
|
-
.. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
330
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
331
|
-
.. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
|
332
|
-
R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
|
333
|
-
13375 (2023)
|
334
|
-
|
335
|
-
Returns
|
336
|
-
-------
|
337
|
-
None
|
338
|
-
This function modifies `template` in-place and does not return any value.
|
339
|
-
"""
|
340
|
-
masked_mean = backend.sum(backend.multiply(template, mask))
|
341
|
-
masked_mean = backend.divide(masked_mean, mask_intensity)
|
342
|
-
masked_std = backend.sum(backend.multiply(backend.square(template), mask))
|
343
|
-
masked_std = backend.subtract(
|
344
|
-
masked_std / mask_intensity, backend.square(masked_mean)
|
345
|
-
)
|
346
|
-
masked_std = backend.sqrt(backend.maximum(masked_std, 0))
|
347
|
-
|
348
|
-
backend.subtract(template, masked_mean, out=template)
|
349
|
-
backend.divide(template, masked_std, out=template)
|
350
|
-
backend.multiply(template, mask, out=template)
|
351
|
-
return None
|
352
|
-
|
353
|
-
|
354
367
|
def flc_setup(
|
355
368
|
rfftn: Callable,
|
356
369
|
irfftn: Callable,
|
@@ -413,7 +426,7 @@ def flc_setup(
|
|
413
426
|
arr=ft_target2, shared_memory_handler=shared_memory_handler
|
414
427
|
)
|
415
428
|
|
416
|
-
|
429
|
+
normalize_under_mask(
|
417
430
|
template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
|
418
431
|
)
|
419
432
|
|
@@ -535,13 +548,16 @@ def flcSphericalMask_setup(
|
|
535
548
|
backend.fill(temp2, 0)
|
536
549
|
temp2[nonzero_indices] = 1 / temp[nonzero_indices]
|
537
550
|
|
538
|
-
|
551
|
+
normalize_under_mask(
|
539
552
|
template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
|
540
553
|
)
|
541
554
|
|
542
555
|
template_buffer = backend.arr_to_sharedarr(
|
543
556
|
arr=template, shared_memory_handler=shared_memory_handler
|
544
557
|
)
|
558
|
+
template_mask_buffer = backend.arr_to_sharedarr(
|
559
|
+
arr=template_mask, shared_memory_handler=shared_memory_handler
|
560
|
+
)
|
545
561
|
target_ft_buffer = backend.arr_to_sharedarr(
|
546
562
|
arr=ft_target, shared_memory_handler=shared_memory_handler
|
547
563
|
)
|
@@ -553,6 +569,7 @@ def flcSphericalMask_setup(
|
|
553
569
|
)
|
554
570
|
|
555
571
|
template_tuple = (template_buffer, template.shape, real_dtype)
|
572
|
+
template_mask_tuple = (template_mask_buffer, template.shape, real_dtype)
|
556
573
|
target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
|
557
574
|
|
558
575
|
inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
|
@@ -560,6 +577,7 @@ def flcSphericalMask_setup(
|
|
560
577
|
|
561
578
|
ret = {
|
562
579
|
"template": template_tuple,
|
580
|
+
"template_mask": template_mask_tuple,
|
563
581
|
"ft_target": target_ft_tuple,
|
564
582
|
"inv_denominator": inv_denominator_tuple,
|
565
583
|
"numerator2": numerator2_tuple,
|
@@ -717,10 +735,6 @@ def corr_scoring(
|
|
717
735
|
datatype.
|
718
736
|
numerator2 : Tuple[type, Tuple[int], type]
|
719
737
|
Tuple containing a pointer to the numerator2 data, its shape, and its datatype.
|
720
|
-
targetshape : Tuple[int]
|
721
|
-
The shape of the target.
|
722
|
-
templateshape : Tuple[int]
|
723
|
-
The shape of the template.
|
724
738
|
fast_shape : Tuple[int]
|
725
739
|
The shape for fast Fourier transform.
|
726
740
|
fast_ft_shape : Tuple[int]
|
@@ -739,8 +753,6 @@ def corr_scoring(
|
|
739
753
|
instantiable.
|
740
754
|
interpolation_order : int
|
741
755
|
The order of interpolation to be used while rotating the template.
|
742
|
-
convolution_mode : str, optional
|
743
|
-
Mode to use for convolution, default is "full".
|
744
756
|
**kwargs :
|
745
757
|
Additional arguments to be passed to the function.
|
746
758
|
|
@@ -756,31 +768,23 @@ def corr_scoring(
|
|
756
768
|
:py:meth:`cam_setup`
|
757
769
|
:py:meth:`flcSphericalMask_setup`
|
758
770
|
"""
|
759
|
-
|
760
|
-
ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
|
761
|
-
inv_denominator_buffer, inv_denominator_pointer_shape, _ = inv_denominator
|
762
|
-
numerator2_buffer, numerator2_shape, _ = numerator2
|
763
|
-
filter_buffer, filter_shape, filter_dtype = template_filter
|
764
|
-
|
771
|
+
callback = callback_class
|
765
772
|
if callback_class is not None and isinstance(callback_class, type):
|
766
773
|
callback = callback_class(**callback_class_args)
|
767
|
-
elif not isinstance(callback_class, type):
|
768
|
-
callback = callback_class
|
769
774
|
|
770
|
-
|
771
|
-
template = backend.sharedarr_to_arr(template_shape, template_dtype
|
772
|
-
ft_target = backend.sharedarr_to_arr(
|
773
|
-
|
774
|
-
)
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
)
|
775
|
+
template_buffer, template_shape, template_dtype = template
|
776
|
+
template = backend.sharedarr_to_arr(template_buffer, template_shape, template_dtype)
|
777
|
+
ft_target = backend.sharedarr_to_arr(*ft_target)
|
778
|
+
inv_denominator = backend.sharedarr_to_arr(*inv_denominator)
|
779
|
+
numerator2 = backend.sharedarr_to_arr(*numerator2)
|
780
|
+
template_filter = backend.sharedarr_to_arr(*template_filter)
|
781
|
+
|
782
|
+
norm_template, template_mask, mask_sum = False, 1, 1
|
783
|
+
if "template_mask" in kwargs:
|
784
|
+
template_mask = backend.sharedarr_to_arr(
|
785
|
+
kwargs["template_mask"][0], template_shape, template_dtype
|
786
|
+
)
|
787
|
+
norm_template, mask_sum = True, backend.sum(template_mask)
|
784
788
|
|
785
789
|
arr = backend.preallocate_array(fast_shape, real_dtype)
|
786
790
|
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
@@ -799,17 +803,16 @@ def corr_scoring(
|
|
799
803
|
norm_denominator = (backend.sum(inv_denominator) != 1) & (
|
800
804
|
backend.size(inv_denominator) != 1
|
801
805
|
)
|
802
|
-
filter_template = backend.size(template_filter) != 0
|
803
806
|
|
807
|
+
norm_template = conditional_execute(normalize_under_mask, norm_template)
|
808
|
+
callback_func = conditional_execute(callback, callback_class is not None)
|
804
809
|
norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
|
805
810
|
norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
|
806
|
-
template_filter_func = conditional_execute(
|
807
|
-
|
808
|
-
|
809
|
-
fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
|
810
|
-
fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
|
811
|
+
template_filter_func = conditional_execute(
|
812
|
+
apply_filter, backend.size(template_filter) != 1
|
813
|
+
)
|
811
814
|
|
812
|
-
|
815
|
+
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
813
816
|
for index in range(rotations.shape[0]):
|
814
817
|
rotation = rotations[index]
|
815
818
|
backend.fill(arr, 0)
|
@@ -820,33 +823,24 @@ def corr_scoring(
|
|
820
823
|
use_geometric_center=False,
|
821
824
|
order=interpolation_order,
|
822
825
|
)
|
823
|
-
rotation_norm = template_sum / backend.sum(arr)
|
824
|
-
backend.multiply(arr, rotation_norm, out=arr)
|
825
826
|
|
826
|
-
|
827
|
-
template_filter_func(ft_temp, template_filter, out=ft_temp)
|
827
|
+
norm_template(arr[unpadded_slice], template_mask, mask_sum)
|
828
828
|
|
829
|
+
rfftn(arr, ft_temp)
|
830
|
+
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
829
831
|
backend.multiply(ft_target, ft_temp, out=ft_temp)
|
830
832
|
irfftn(ft_temp, arr)
|
831
833
|
|
832
834
|
norm_func_numerator(arr, numerator2, out=arr)
|
833
835
|
norm_func_denominator(arr, inv_denominator, out=arr)
|
834
836
|
|
835
|
-
|
836
|
-
arr
|
837
|
-
|
838
|
-
|
839
|
-
|
837
|
+
callback_func(
|
838
|
+
arr,
|
839
|
+
rotation_matrix=rotation,
|
840
|
+
rotation_index=index,
|
841
|
+
**callback_class_args,
|
840
842
|
)
|
841
843
|
|
842
|
-
if callback_class is not None:
|
843
|
-
callback(
|
844
|
-
score,
|
845
|
-
rotation_matrix=rotation,
|
846
|
-
rotation_index=index,
|
847
|
-
**callback_class_args,
|
848
|
-
)
|
849
|
-
|
850
844
|
return callback
|
851
845
|
|
852
846
|
|
@@ -894,32 +888,15 @@ def flc_scoring(
|
|
894
888
|
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
895
889
|
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
896
890
|
"""
|
897
|
-
|
898
|
-
template_mask_buffer, *_ = template_mask
|
899
|
-
filter_buffer, filter_shape, filter_dtype = template_filter
|
900
|
-
|
901
|
-
ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
|
902
|
-
ft_target2_buffer, *_ = ft_target2
|
903
|
-
|
891
|
+
callback = callback_class
|
904
892
|
if callback_class is not None and isinstance(callback_class, type):
|
905
893
|
callback = callback_class(**callback_class_args)
|
906
|
-
elif not isinstance(callback_class, type):
|
907
|
-
callback = callback_class
|
908
894
|
|
909
|
-
|
910
|
-
|
911
|
-
|
912
|
-
|
913
|
-
)
|
914
|
-
ft_target = backend.sharedarr_to_arr(
|
915
|
-
ft_target_shape, ft_target_dtype, ft_target_buffer
|
916
|
-
)
|
917
|
-
ft_target2 = backend.sharedarr_to_arr(
|
918
|
-
ft_target_shape, ft_target_dtype, ft_target2_buffer
|
919
|
-
)
|
920
|
-
template_filter = backend.sharedarr_to_arr(
|
921
|
-
filter_shape, filter_dtype, filter_buffer
|
922
|
-
)
|
895
|
+
template = backend.sharedarr_to_arr(*template)
|
896
|
+
template_mask = backend.sharedarr_to_arr(*template_mask)
|
897
|
+
ft_target = backend.sharedarr_to_arr(*ft_target)
|
898
|
+
ft_target2 = backend.sharedarr_to_arr(*ft_target2)
|
899
|
+
template_filter = backend.sharedarr_to_arr(*template_filter)
|
923
900
|
|
924
901
|
arr = backend.preallocate_array(fast_shape, real_dtype)
|
925
902
|
temp = backend.preallocate_array(fast_shape, real_dtype)
|
@@ -938,12 +915,10 @@ def flc_scoring(
|
|
938
915
|
temp_fft=ft_temp,
|
939
916
|
)
|
940
917
|
eps = backend.eps(real_dtype)
|
941
|
-
|
942
|
-
|
943
|
-
|
944
|
-
|
945
|
-
fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
|
946
|
-
fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
|
918
|
+
template_filter_func = conditional_execute(
|
919
|
+
apply_filter, backend.size(template_filter) != 1
|
920
|
+
)
|
921
|
+
callback_func = conditional_execute(callback, callback_class is not None)
|
947
922
|
|
948
923
|
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
949
924
|
for index in range(rotations.shape[0]):
|
@@ -962,7 +937,7 @@ def flc_scoring(
|
|
962
937
|
# Given the number of FFTs, we might as well normalize properly
|
963
938
|
n_observations = backend.sum(temp)
|
964
939
|
|
965
|
-
|
940
|
+
normalize_under_mask(
|
966
941
|
template=arr[unpadded_slice],
|
967
942
|
mask=temp[unpadded_slice],
|
968
943
|
mask_intensity=n_observations,
|
@@ -985,7 +960,7 @@ def flc_scoring(
|
|
985
960
|
backend.multiply(temp, n_observations, out=temp)
|
986
961
|
|
987
962
|
rfftn(arr, ft_temp)
|
988
|
-
template_filter_func(ft_temp, template_filter
|
963
|
+
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
989
964
|
backend.multiply(ft_target, ft_temp, out=ft_temp)
|
990
965
|
irfftn(ft_temp, arr)
|
991
966
|
|
@@ -994,23 +969,161 @@ def flc_scoring(
|
|
994
969
|
backend.fill(temp2, 0)
|
995
970
|
temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
|
996
971
|
|
997
|
-
|
972
|
+
callback_func(
|
973
|
+
temp2,
|
974
|
+
rotation_matrix=rotation,
|
975
|
+
rotation_index=index,
|
976
|
+
**callback_class_args,
|
977
|
+
)
|
978
|
+
|
979
|
+
return callback
|
998
980
|
|
999
|
-
if fourier_shift_scores:
|
1000
|
-
temp2 = backend.roll(temp2, shift=fourier_shift, axis=axis)
|
1001
981
|
|
1002
|
-
|
1003
|
-
|
982
|
+
def flc_scoring2(
|
983
|
+
template: Tuple[type, Tuple[int], type],
|
984
|
+
template_mask: Tuple[type, Tuple[int], type],
|
985
|
+
ft_target: Tuple[type, Tuple[int], type],
|
986
|
+
ft_target2: Tuple[type, Tuple[int], type],
|
987
|
+
template_filter: Tuple[type, Tuple[int], type],
|
988
|
+
targetshape: Tuple[int],
|
989
|
+
templateshape: Tuple[int],
|
990
|
+
fast_shape: Tuple[int],
|
991
|
+
fast_ft_shape: Tuple[int],
|
992
|
+
rotations: NDArray,
|
993
|
+
real_dtype: type,
|
994
|
+
complex_dtype: type,
|
995
|
+
callback_class: CallbackClass,
|
996
|
+
callback_class_args: Dict,
|
997
|
+
interpolation_order: int,
|
998
|
+
**kwargs,
|
999
|
+
) -> CallbackClass:
|
1000
|
+
"""
|
1001
|
+
Computes a normalized cross-correlation score of a target f, a template g
|
1002
|
+
and a mask m:
|
1003
|
+
|
1004
|
+
.. math::
|
1005
|
+
|
1006
|
+
\\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
|
1007
|
+
{N_m * \\sqrt{
|
1008
|
+
\\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
|
1009
|
+
}
|
1010
|
+
|
1011
|
+
Where:
|
1012
|
+
|
1013
|
+
.. math::
|
1014
|
+
|
1015
|
+
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
1016
|
+
|
1017
|
+
and Nm is the number of voxels within the template mask m.
|
1018
|
+
|
1019
|
+
References
|
1020
|
+
----------
|
1021
|
+
.. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
|
1022
|
+
Microsc. Microanal. 26, 2516 (2020)
|
1023
|
+
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
1024
|
+
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
1025
|
+
"""
|
1026
|
+
callback = callback_class
|
1027
|
+
if callback_class is not None and isinstance(callback_class, type):
|
1028
|
+
callback = callback_class(**callback_class_args)
|
1029
|
+
|
1030
|
+
# Retrieve objects from shared memory
|
1031
|
+
template = backend.sharedarr_to_arr(*template)
|
1032
|
+
template_mask = backend.sharedarr_to_arr(*template_mask)
|
1033
|
+
ft_target = backend.sharedarr_to_arr(*ft_target)
|
1034
|
+
ft_target2 = backend.sharedarr_to_arr(*ft_target2)
|
1035
|
+
template_filter = backend.sharedarr_to_arr(*template_filter)
|
1036
|
+
|
1037
|
+
arr = backend.preallocate_array(fast_shape, real_dtype)
|
1038
|
+
temp = backend.preallocate_array(fast_shape, real_dtype)
|
1039
|
+
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
1040
|
+
|
1041
|
+
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
1042
|
+
ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
1043
|
+
|
1044
|
+
eps = backend.eps(real_dtype)
|
1045
|
+
template_filter_func = conditional_execute(
|
1046
|
+
apply_filter, backend.size(template_filter) != 1
|
1047
|
+
)
|
1048
|
+
callback_func = conditional_execute(callback, callback_class is not None)
|
1049
|
+
|
1050
|
+
squeeze_axis = tuple(i for i, x in enumerate(template.shape) if x == 1)
|
1051
|
+
squeeze = tuple(
|
1052
|
+
slice(0, stop) if i not in squeeze_axis else 0
|
1053
|
+
for i, stop in enumerate(template.shape)
|
1054
|
+
)
|
1055
|
+
squeeze_fast = tuple(
|
1056
|
+
slice(0, stop) if i not in squeeze_axis else 0
|
1057
|
+
for i, stop in enumerate(fast_shape)
|
1058
|
+
)
|
1059
|
+
squeeze_fast_ft = tuple(
|
1060
|
+
slice(0, stop) if i not in squeeze_axis else 0
|
1061
|
+
for i, stop in enumerate(fast_ft_shape)
|
1062
|
+
)
|
1063
|
+
|
1064
|
+
rfftn, irfftn = backend.build_fft(
|
1065
|
+
fast_shape=temp[squeeze_fast].shape,
|
1066
|
+
fast_ft_shape=fast_ft_shape,
|
1067
|
+
real_dtype=real_dtype,
|
1068
|
+
complex_dtype=complex_dtype,
|
1069
|
+
fftargs=kwargs.get("fftargs", {}),
|
1070
|
+
inverse_fast_shape=fast_shape,
|
1071
|
+
temp_real=arr[squeeze_fast],
|
1072
|
+
temp_fft=ft_temp,
|
1073
|
+
)
|
1074
|
+
for index in range(rotations.shape[0]):
|
1075
|
+
rotation = rotations[index]
|
1076
|
+
backend.fill(arr, 0)
|
1077
|
+
backend.fill(temp, 0)
|
1078
|
+
backend.rotate_array(
|
1079
|
+
arr=template[squeeze],
|
1080
|
+
arr_mask=template_mask[squeeze],
|
1081
|
+
rotation_matrix=rotation,
|
1082
|
+
out=arr[squeeze],
|
1083
|
+
out_mask=temp[squeeze],
|
1084
|
+
use_geometric_center=False,
|
1085
|
+
order=interpolation_order,
|
1086
|
+
)
|
1087
|
+
# Given the number of FFTs, we might as well normalize properly
|
1088
|
+
n_observations = backend.sum(temp)
|
1089
|
+
|
1090
|
+
normalize_under_mask(
|
1091
|
+
template=arr[squeeze],
|
1092
|
+
mask=temp[squeeze],
|
1093
|
+
mask_intensity=n_observations,
|
1004
1094
|
)
|
1095
|
+
rfftn(temp[squeeze_fast], ft_temp[squeeze_fast_ft])
|
1005
1096
|
|
1006
|
-
|
1007
|
-
|
1008
|
-
|
1009
|
-
|
1010
|
-
rotation_index=index,
|
1011
|
-
**callback_class_args,
|
1012
|
-
)
|
1097
|
+
backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1098
|
+
irfftn(ft_denom, temp)
|
1099
|
+
backend.divide(temp, n_observations, out=temp)
|
1100
|
+
backend.square(temp, out=temp)
|
1013
1101
|
|
1102
|
+
backend.multiply(ft_target2, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1103
|
+
irfftn(ft_denom, temp2)
|
1104
|
+
backend.divide(temp2, n_observations, out=temp2)
|
1105
|
+
|
1106
|
+
backend.subtract(temp2, temp, out=temp)
|
1107
|
+
backend.maximum(temp, 0.0, out=temp)
|
1108
|
+
backend.sqrt(temp, out=temp)
|
1109
|
+
backend.multiply(temp, n_observations, out=temp)
|
1110
|
+
|
1111
|
+
rfftn(arr[squeeze_fast], ft_temp[squeeze_fast_ft])
|
1112
|
+
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
1113
|
+
backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1114
|
+
irfftn(ft_denom, arr)
|
1115
|
+
|
1116
|
+
tol = tol = 1e3 * eps * backend.max(backend.abs(temp))
|
1117
|
+
nonzero_indices = temp > tol
|
1118
|
+
backend.fill(temp2, 0)
|
1119
|
+
temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
|
1120
|
+
|
1121
|
+
callback_func(
|
1122
|
+
temp2,
|
1123
|
+
rotation_matrix=rotation,
|
1124
|
+
rotation_index=index,
|
1125
|
+
**callback_class_args,
|
1126
|
+
)
|
1014
1127
|
return callback
|
1015
1128
|
|
1016
1129
|
|
@@ -1064,35 +1177,18 @@ def mcc_scoring(
|
|
1064
1177
|
--------
|
1065
1178
|
:py:class:`tme.matching_optimization.MaskedCrossCorrelation`
|
1066
1179
|
"""
|
1067
|
-
|
1068
|
-
ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
|
1069
|
-
ft_target2_buffer, ft_target_shape, ft_target_dtype = ft_target2
|
1070
|
-
template_mask_buffer, _, _ = template
|
1071
|
-
ft_target_mask_buffer, _, _ = ft_target
|
1072
|
-
filter_buffer, filter_shape, filter_dtype = template_filter
|
1073
|
-
|
1180
|
+
callback = callback_class
|
1074
1181
|
if callback_class is not None and isinstance(callback_class, type):
|
1075
1182
|
callback = callback_class(**callback_class_args)
|
1076
|
-
elif not isinstance(callback_class, type):
|
1077
|
-
callback = callback_class
|
1078
1183
|
|
1079
1184
|
# Retrieve objects from shared memory
|
1080
|
-
|
1081
|
-
|
1082
|
-
|
1083
|
-
)
|
1084
|
-
|
1085
|
-
|
1086
|
-
)
|
1087
|
-
template_mask = backend.sharedarr_to_arr(
|
1088
|
-
template_shape, template_dtype, template_mask_buffer
|
1089
|
-
)
|
1090
|
-
target_mask_ft = backend.sharedarr_to_arr(
|
1091
|
-
ft_target_shape, ft_target_dtype, ft_target_mask_buffer
|
1092
|
-
)
|
1093
|
-
template_filter = backend.sharedarr_to_arr(
|
1094
|
-
filter_shape, filter_dtype, filter_buffer
|
1095
|
-
)
|
1185
|
+
template_buffer, template_shape, template_dtype = template
|
1186
|
+
template = backend.sharedarr_to_arr(*template)
|
1187
|
+
target_ft = backend.sharedarr_to_arr(*ft_target)
|
1188
|
+
target_ft2 = backend.sharedarr_to_arr(*ft_target2)
|
1189
|
+
template_mask = backend.sharedarr_to_arr(*template_mask)
|
1190
|
+
target_mask_ft = backend.sharedarr_to_arr(*ft_target_mask)
|
1191
|
+
template_filter = backend.sharedarr_to_arr(*template_filter)
|
1096
1192
|
|
1097
1193
|
axes = tuple(range(template.ndim))
|
1098
1194
|
eps = backend.eps(real_dtype)
|
@@ -1117,14 +1213,10 @@ def mcc_scoring(
|
|
1117
1213
|
temp_fft=temp_ft,
|
1118
1214
|
)
|
1119
1215
|
|
1120
|
-
|
1121
|
-
|
1122
|
-
|
1123
|
-
axis = tuple(range(template.ndim))
|
1124
|
-
fourier_shift = callback_class_args.get(
|
1125
|
-
"fourier_shift", backend.zeros(template.ndim)
|
1216
|
+
template_filter_func = conditional_execute(
|
1217
|
+
apply_filter, backend.size(template_filter) != 1
|
1126
1218
|
)
|
1127
|
-
|
1219
|
+
callback_func = conditional_execute(callback, callback_class is not None)
|
1128
1220
|
|
1129
1221
|
# Calculate scores across all rotations
|
1130
1222
|
for index in range(rotations.shape[0]):
|
@@ -1146,7 +1238,7 @@ def mcc_scoring(
|
|
1146
1238
|
|
1147
1239
|
# template_rot_ft
|
1148
1240
|
rfftn(template_rot, temp_ft)
|
1149
|
-
template_filter_func(temp_ft, template_filter
|
1241
|
+
template_filter_func(ft_template=temp_ft, template_filter=template_filter)
|
1150
1242
|
irfftn(target_mask_ft * temp_ft, temp2)
|
1151
1243
|
irfftn(target_ft * temp_ft, numerator)
|
1152
1244
|
|
@@ -1202,25 +1294,68 @@ def mcc_scoring(
|
|
1202
1294
|
mask_overlap, axis=axes, keepdims=True
|
1203
1295
|
)
|
1204
1296
|
temp[mask_overlap < number_px_threshold] = 0.0
|
1205
|
-
convolution_mode = kwargs.get("convolution_mode", "full")
|
1206
1297
|
|
1207
|
-
|
1208
|
-
temp
|
1209
|
-
|
1210
|
-
|
1211
|
-
|
1298
|
+
callback_func(
|
1299
|
+
temp,
|
1300
|
+
rotation_matrix=rotation,
|
1301
|
+
rotation_index=index,
|
1302
|
+
**callback_class_args,
|
1212
1303
|
)
|
1213
|
-
if callback_class is not None:
|
1214
|
-
callback(
|
1215
|
-
score,
|
1216
|
-
rotation_matrix=rotation,
|
1217
|
-
rotation_index=index,
|
1218
|
-
**callback_class_args,
|
1219
|
-
)
|
1220
1304
|
|
1221
1305
|
return callback
|
1222
1306
|
|
1223
1307
|
|
1308
|
+
def _setup_template_filter_apply_target_filter(
|
1309
|
+
matching_data: MatchingData,
|
1310
|
+
rfftn: Callable,
|
1311
|
+
irfftn: Callable,
|
1312
|
+
fast_shape: Tuple[int],
|
1313
|
+
fast_ft_shape: Tuple[int],
|
1314
|
+
):
|
1315
|
+
filter_template = isinstance(matching_data.template_filter, Compose)
|
1316
|
+
filter_target = isinstance(matching_data.target_filter, Compose)
|
1317
|
+
|
1318
|
+
template_filter = backend.full(
|
1319
|
+
shape=(1,), fill_value=1, dtype=backend._default_dtype
|
1320
|
+
)
|
1321
|
+
|
1322
|
+
if not filter_template and not filter_target:
|
1323
|
+
return template_filter
|
1324
|
+
|
1325
|
+
target_temp = backend.astype(
|
1326
|
+
backend.topleft_pad(matching_data.target, fast_shape), backend._default_dtype
|
1327
|
+
)
|
1328
|
+
target_temp_ft = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
|
1329
|
+
rfftn(target_temp, target_temp_ft)
|
1330
|
+
if isinstance(matching_data.template_filter, Compose):
|
1331
|
+
template_filter = matching_data.template_filter(
|
1332
|
+
shape=fast_shape,
|
1333
|
+
return_real_fourier=True,
|
1334
|
+
shape_is_real_fourier=False,
|
1335
|
+
data_rfft=target_temp_ft,
|
1336
|
+
)
|
1337
|
+
template_filter = template_filter["data"]
|
1338
|
+
template_filter[tuple(0 for _ in range(template_filter.ndim))] = 0
|
1339
|
+
|
1340
|
+
if isinstance(matching_data.target_filter, Compose):
|
1341
|
+
target_filter = matching_data.target_filter(
|
1342
|
+
shape=fast_shape,
|
1343
|
+
return_real_fourier=True,
|
1344
|
+
shape_is_real_fourier=False,
|
1345
|
+
data_rfft=target_temp_ft,
|
1346
|
+
weight_type=None,
|
1347
|
+
)
|
1348
|
+
target_filter = target_filter["data"]
|
1349
|
+
backend.multiply(target_temp_ft, target_filter, out=target_temp_ft)
|
1350
|
+
|
1351
|
+
irfftn(target_temp_ft, target_temp)
|
1352
|
+
matching_data._target = backend.topleft_pad(
|
1353
|
+
target_temp, matching_data.target.shape
|
1354
|
+
)
|
1355
|
+
|
1356
|
+
return template_filter
|
1357
|
+
|
1358
|
+
|
1224
1359
|
def device_memory_handler(func: Callable):
|
1225
1360
|
"""Decorator function providing SharedMemory Handler."""
|
1226
1361
|
|
@@ -1291,18 +1426,17 @@ def scan(
|
|
1291
1426
|
Tuple
|
1292
1427
|
The merged results from callback_class if provided otherwise None.
|
1293
1428
|
"""
|
1429
|
+
matching_data.to_backend()
|
1294
1430
|
shape_diff = backend.subtract(
|
1295
|
-
matching_data.
|
1431
|
+
matching_data._output_target_shape, matching_data._output_template_shape
|
1296
1432
|
)
|
1433
|
+
shape_diff = backend.multiply(shape_diff, ~matching_data._batch_mask)
|
1297
1434
|
if backend.sum(shape_diff < 0) and not pad_fourier:
|
1298
1435
|
warnings.warn(
|
1299
1436
|
"Target is larger than template and Fourier padding is turned off."
|
1300
|
-
" This can lead to shifted results. You can swap template and target,
|
1301
|
-
" zero-pad the target."
|
1437
|
+
" This can lead to shifted results. You can swap template and target, "
|
1438
|
+
" zero-pad the target or turn off template centering."
|
1302
1439
|
)
|
1303
|
-
|
1304
|
-
matching_data.to_backend()
|
1305
|
-
|
1306
1440
|
fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
|
1307
1441
|
pad_fourier=pad_fourier
|
1308
1442
|
)
|
@@ -1315,6 +1449,15 @@ def scan(
|
|
1315
1449
|
complex_dtype=matching_data._complex_dtype,
|
1316
1450
|
fftargs=fftargs,
|
1317
1451
|
)
|
1452
|
+
|
1453
|
+
template_filter = _setup_template_filter_apply_target_filter(
|
1454
|
+
matching_data=matching_data,
|
1455
|
+
rfftn=rfftn,
|
1456
|
+
irfftn=irfftn,
|
1457
|
+
fast_shape=fast_shape,
|
1458
|
+
fast_ft_shape=fast_ft_shape,
|
1459
|
+
)
|
1460
|
+
|
1318
1461
|
setup = matching_setup(
|
1319
1462
|
rfftn=rfftn,
|
1320
1463
|
irfftn=irfftn,
|
@@ -1332,22 +1475,7 @@ def scan(
|
|
1332
1475
|
)
|
1333
1476
|
rfftn, irfftn = None, None
|
1334
1477
|
|
1335
|
-
template_filter
|
1336
|
-
for method, parameters in matching_data.template_filter.items():
|
1337
|
-
parameters["shape"] = fast_shape
|
1338
|
-
parameters["omit_negative_frequencies"] = True
|
1339
|
-
out = preprocessor.apply_method(method=method, parameters=parameters)
|
1340
|
-
if template_filter is None:
|
1341
|
-
template_filter = out
|
1342
|
-
np.multiply(template_filter, out, out=template_filter)
|
1343
|
-
|
1344
|
-
if template_filter is None:
|
1345
|
-
template_filter = backend.full(
|
1346
|
-
shape=(1,), fill_value=1, dtype=backend._default_dtype
|
1347
|
-
)
|
1348
|
-
else:
|
1349
|
-
template_filter = backend.to_backend_array(template_filter)
|
1350
|
-
|
1478
|
+
template_filter = backend.to_backend_array(template_filter)
|
1351
1479
|
template_filter = backend.astype(template_filter, backend._default_dtype)
|
1352
1480
|
template_filter_buffer = backend.arr_to_sharedarr(
|
1353
1481
|
arr=template_filter,
|
@@ -1369,14 +1497,20 @@ def scan(
|
|
1369
1497
|
callback_class = setup.pop("callback_class", callback_class)
|
1370
1498
|
callback_class_args = setup.pop("callback_class_args", callback_class_args)
|
1371
1499
|
callback_classes = [callback_class for _ in range(n_callback_classes)]
|
1500
|
+
|
1501
|
+
convolution_mode = "same"
|
1502
|
+
if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
|
1503
|
+
convolution_mode = "valid"
|
1504
|
+
|
1505
|
+
callback_class_args["fourier_shift"] = fourier_shift
|
1506
|
+
callback_class_args["convolution_mode"] = convolution_mode
|
1507
|
+
callback_class_args["targetshape"] = setup["targetshape"]
|
1508
|
+
callback_class_args["templateshape"] = setup["templateshape"]
|
1509
|
+
|
1372
1510
|
if callback_class == MaxScoreOverRotations:
|
1373
|
-
score_space_shape = backend.subtract(
|
1374
|
-
matching_data.target.shape,
|
1375
|
-
matching_data._target_pad,
|
1376
|
-
)
|
1377
1511
|
callback_classes = [
|
1378
1512
|
class_name(
|
1379
|
-
score_space_shape=
|
1513
|
+
score_space_shape=fast_shape,
|
1380
1514
|
score_space_dtype=matching_data._default_dtype,
|
1381
1515
|
shared_memory_handler=kwargs.get("shared_memory_handler", None),
|
1382
1516
|
rotation_space_dtype=backend._default_dtype_int,
|
@@ -1416,9 +1550,20 @@ def scan(
|
|
1416
1550
|
for index, rotation in enumerate(rotation_list)
|
1417
1551
|
)
|
1418
1552
|
|
1553
|
+
callbacks = callbacks[0:n_callback_classes]
|
1419
1554
|
callbacks = [
|
1420
|
-
tuple(
|
1421
|
-
|
1555
|
+
tuple(
|
1556
|
+
callback._postprocess(
|
1557
|
+
fourier_shift=fourier_shift,
|
1558
|
+
convolution_mode=convolution_mode,
|
1559
|
+
targetshape=setup["targetshape"],
|
1560
|
+
templateshape=setup["templateshape"],
|
1561
|
+
shared_memory_handler=kwargs.get("shared_memory_handler", None),
|
1562
|
+
)
|
1563
|
+
)
|
1564
|
+
if hasattr(callback, "_postprocess")
|
1565
|
+
else tuple(callback)
|
1566
|
+
for callback in callbacks
|
1422
1567
|
if callback is not None
|
1423
1568
|
]
|
1424
1569
|
backend.free_cache()
|
@@ -1530,11 +1675,13 @@ def scan_subsets(
|
|
1530
1675
|
matching_data._target, matching_data._template = None, None
|
1531
1676
|
matching_data._target_mask, matching_data._template_mask = None, None
|
1532
1677
|
|
1678
|
+
candidates = None
|
1533
1679
|
if callback_class is not None:
|
1534
1680
|
candidates = callback_class.merge(
|
1535
1681
|
results, **callback_class_args, inner_merge=False
|
1536
1682
|
)
|
1537
|
-
|
1683
|
+
|
1684
|
+
return candidates
|
1538
1685
|
|
1539
1686
|
|
1540
1687
|
MATCHING_EXHAUSTIVE_REGISTER = {
|
@@ -1544,6 +1691,7 @@ MATCHING_EXHAUSTIVE_REGISTER = {
|
|
1544
1691
|
"CAM": (cam_setup, corr_scoring),
|
1545
1692
|
"FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
|
1546
1693
|
"FLC": (flc_setup, flc_scoring),
|
1694
|
+
"FLC2": (flc_setup, flc_scoring2),
|
1547
1695
|
"MCC": (mcc_setup, mcc_scoring),
|
1548
1696
|
}
|
1549
1697
|
|