pytme 0.1.8__cp311-cp311-macosx_14_0_arm64.whl → 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.1.8.data → pytme-0.2.0b0.data}/scripts/match_template.py +148 -126
- pytme-0.2.0b0.data/scripts/postprocess.py +570 -0
- {pytme-0.1.8.data → pytme-0.2.0b0.data}/scripts/preprocessor_gui.py +244 -60
- {pytme-0.1.8.dist-info → pytme-0.2.0b0.dist-info}/METADATA +3 -1
- pytme-0.2.0b0.dist-info/RECORD +66 -0
- {pytme-0.1.8.dist-info → pytme-0.2.0b0.dist-info}/WHEEL +1 -1
- scripts/extract_candidates.py +218 -0
- scripts/match_template.py +148 -126
- scripts/match_template_filters.py +852 -0
- scripts/postprocess.py +380 -435
- scripts/preprocessor_gui.py +244 -60
- scripts/refine_matches.py +218 -0
- tme/__init__.py +2 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +545 -78
- tme/backends/cupy_backend.py +80 -15
- tme/backends/npfftw_backend.py +33 -2
- tme/backends/pytorch_backend.py +15 -7
- tme/density.py +156 -63
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_constrained.py +195 -0
- tme/matching_data.py +76 -32
- tme/matching_exhaustive.py +366 -204
- tme/matching_memory.py +1 -0
- tme/matching_optimization.py +728 -651
- tme/matching_utils.py +152 -8
- tme/orientations.py +561 -0
- tme/preprocessor.py +21 -18
- tme/structure.py +2 -37
- pytme-0.1.8.data/scripts/postprocess.py +0 -625
- pytme-0.1.8.dist-info/RECORD +0 -61
- {pytme-0.1.8.data → pytme-0.2.0b0.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.1.8.data → pytme-0.2.0b0.data}/scripts/preprocess.py +0 -0
- {pytme-0.1.8.dist-info → pytme-0.2.0b0.dist-info}/LICENSE +0 -0
- {pytme-0.1.8.dist-info → pytme-0.2.0b0.dist-info}/entry_points.txt +0 -0
- {pytme-0.1.8.dist-info → pytme-0.2.0b0.dist-info}/top_level.txt +0 -0
tme/matching_exhaustive.py
CHANGED
@@ -25,6 +25,7 @@ from .matching_utils import (
|
|
25
25
|
conditional_execute,
|
26
26
|
)
|
27
27
|
from .matching_memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
|
28
|
+
# from .preprocessing import Compose
|
28
29
|
from .preprocessor import Preprocessor
|
29
30
|
from .matching_data import MatchingData
|
30
31
|
from .backends import backend
|
@@ -43,6 +44,58 @@ def _run_inner(backend_name, backend_args, **kwargs):
|
|
43
44
|
return scan(**kwargs)
|
44
45
|
|
45
46
|
|
47
|
+
def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
|
48
|
+
"""
|
49
|
+
Standardizes the values in in template by subtracting the mean and dividing by the
|
50
|
+
standard deviation based on the elements in mask. Subsequently, the template is
|
51
|
+
multiplied by the mask.
|
52
|
+
|
53
|
+
Parameters
|
54
|
+
----------
|
55
|
+
template : NDArray
|
56
|
+
The data array to be normalized. This array is modified in-place.
|
57
|
+
mask : NDArray
|
58
|
+
A boolean array of the same shape as `template`. True values indicate the positions in `template`
|
59
|
+
to consider for normalization.
|
60
|
+
mask_intensity : float
|
61
|
+
Mask intensity used to compute expectations.
|
62
|
+
|
63
|
+
References
|
64
|
+
----------
|
65
|
+
.. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
66
|
+
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
67
|
+
.. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
|
68
|
+
R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
|
69
|
+
13375 (2023)
|
70
|
+
|
71
|
+
Returns
|
72
|
+
-------
|
73
|
+
None
|
74
|
+
This function modifies `template` in-place and does not return any value.
|
75
|
+
"""
|
76
|
+
masked_mean = backend.sum(backend.multiply(template, mask))
|
77
|
+
masked_mean = backend.divide(masked_mean, mask_intensity)
|
78
|
+
masked_std = backend.sum(backend.multiply(backend.square(template), mask))
|
79
|
+
masked_std = backend.subtract(
|
80
|
+
masked_std / mask_intensity, backend.square(masked_mean)
|
81
|
+
)
|
82
|
+
masked_std = backend.sqrt(backend.maximum(masked_std, 0))
|
83
|
+
|
84
|
+
backend.subtract(template, masked_mean, out=template)
|
85
|
+
backend.divide(template, masked_std, out=template)
|
86
|
+
backend.multiply(template, mask, out=template)
|
87
|
+
return None
|
88
|
+
|
89
|
+
|
90
|
+
def apply_filter(ft_template, template_filter):
|
91
|
+
# This is an approximation to applying the mask, irfftn, normalize, rfftn
|
92
|
+
std_before = backend.std(ft_template)
|
93
|
+
backend.multiply(ft_template, template_filter, out=ft_template)
|
94
|
+
backend.multiply(
|
95
|
+
ft_template, std_before / backend.std(ft_template), out=ft_template
|
96
|
+
)
|
97
|
+
|
98
|
+
|
46
99
|
def cc_setup(
|
47
100
|
rfftn: Callable,
|
48
101
|
irfftn: Callable,
|
@@ -208,11 +261,16 @@ def corr_setup(
|
|
208
261
|
target_pad, ft_target2, ft_window_template = None, None, None
|
209
262
|
|
210
263
|
# Normalizing constants
|
211
|
-
|
212
|
-
|
264
|
+
n_observations = backend.sum(template_mask)
|
265
|
+
template_mean = backend.sum(backend.multiply(template, template_mask))
|
266
|
+
template_mean = backend.divide(template_mean, n_observations)
|
213
267
|
template_ssd = backend.sum(
|
214
|
-
backend.square(
|
268
|
+
backend.square(
|
269
|
+
backend.multiply(backend.multiply(template, template_mean), template_mask)
|
270
|
+
)
|
215
271
|
)
|
272
|
+
template_volume = np.prod(template.shape)
|
273
|
+
backend.multiply(template, template_mask, out=template)
|
216
274
|
|
217
275
|
# Final numerator is score - numerator2
|
218
276
|
numerator2 = backend.multiply(target_window_sum, template_mean)
|
@@ -308,49 +366,6 @@ def cam_setup(**kwargs):
|
|
308
366
|
return corr_setup(**kwargs)
|
309
367
|
|
310
368
|
|
311
|
-
def _normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
|
312
|
-
"""
|
313
|
-
Standardizes the values in in template by subtracting the mean and dividing by the
|
314
|
-
standard deviation based on the elements in mask. Subsequently, the template is
|
315
|
-
multiplied by the mask.
|
316
|
-
|
317
|
-
Parameters
|
318
|
-
----------
|
319
|
-
template : NDArray
|
320
|
-
The data array to be normalized. This array is modified in-place.
|
321
|
-
mask : NDArray
|
322
|
-
A boolean array of the same shape as `template`. True values indicate the positions in `template`
|
323
|
-
to consider for normalization.
|
324
|
-
mask_intensity : float
|
325
|
-
Mask intensity used to compute expectations.
|
326
|
-
|
327
|
-
References
|
328
|
-
----------
|
329
|
-
.. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
330
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
331
|
-
.. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
|
332
|
-
R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
|
333
|
-
13375 (2023)
|
334
|
-
|
335
|
-
Returns
|
336
|
-
-------
|
337
|
-
None
|
338
|
-
This function modifies `template` in-place and does not return any value.
|
339
|
-
"""
|
340
|
-
masked_mean = backend.sum(backend.multiply(template, mask))
|
341
|
-
masked_mean = backend.divide(masked_mean, mask_intensity)
|
342
|
-
masked_std = backend.sum(backend.multiply(backend.square(template), mask))
|
343
|
-
masked_std = backend.subtract(
|
344
|
-
masked_std / mask_intensity, backend.square(masked_mean)
|
345
|
-
)
|
346
|
-
masked_std = backend.sqrt(backend.maximum(masked_std, 0))
|
347
|
-
|
348
|
-
backend.subtract(template, masked_mean, out=template)
|
349
|
-
backend.divide(template, masked_std, out=template)
|
350
|
-
backend.multiply(template, mask, out=template)
|
351
|
-
return None
|
352
|
-
|
353
|
-
|
354
369
|
def flc_setup(
|
355
370
|
rfftn: Callable,
|
356
371
|
irfftn: Callable,
|
@@ -413,7 +428,7 @@ def flc_setup(
|
|
413
428
|
arr=ft_target2, shared_memory_handler=shared_memory_handler
|
414
429
|
)
|
415
430
|
|
416
|
-
|
431
|
+
normalize_under_mask(
|
417
432
|
template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
|
418
433
|
)
|
419
434
|
|
@@ -535,13 +550,16 @@ def flcSphericalMask_setup(
|
|
535
550
|
backend.fill(temp2, 0)
|
536
551
|
temp2[nonzero_indices] = 1 / temp[nonzero_indices]
|
537
552
|
|
538
|
-
|
553
|
+
normalize_under_mask(
|
539
554
|
template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
|
540
555
|
)
|
541
556
|
|
542
557
|
template_buffer = backend.arr_to_sharedarr(
|
543
558
|
arr=template, shared_memory_handler=shared_memory_handler
|
544
559
|
)
|
560
|
+
template_mask_buffer = backend.arr_to_sharedarr(
|
561
|
+
arr=template_mask, shared_memory_handler=shared_memory_handler
|
562
|
+
)
|
545
563
|
target_ft_buffer = backend.arr_to_sharedarr(
|
546
564
|
arr=ft_target, shared_memory_handler=shared_memory_handler
|
547
565
|
)
|
@@ -553,6 +571,7 @@ def flcSphericalMask_setup(
|
|
553
571
|
)
|
554
572
|
|
555
573
|
template_tuple = (template_buffer, template.shape, real_dtype)
|
574
|
+
template_mask_tuple = (template_mask_buffer, template.shape, real_dtype)
|
556
575
|
target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
|
557
576
|
|
558
577
|
inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
|
@@ -560,6 +579,7 @@ def flcSphericalMask_setup(
|
|
560
579
|
|
561
580
|
ret = {
|
562
581
|
"template": template_tuple,
|
582
|
+
"template_mask": template_mask_tuple,
|
563
583
|
"ft_target": target_ft_tuple,
|
564
584
|
"inv_denominator": inv_denominator_tuple,
|
565
585
|
"numerator2": numerator2_tuple,
|
@@ -717,10 +737,6 @@ def corr_scoring(
|
|
717
737
|
datatype.
|
718
738
|
numerator2 : Tuple[type, Tuple[int], type]
|
719
739
|
Tuple containing a pointer to the numerator2 data, its shape, and its datatype.
|
720
|
-
targetshape : Tuple[int]
|
721
|
-
The shape of the target.
|
722
|
-
templateshape : Tuple[int]
|
723
|
-
The shape of the template.
|
724
740
|
fast_shape : Tuple[int]
|
725
741
|
The shape for fast Fourier transform.
|
726
742
|
fast_ft_shape : Tuple[int]
|
@@ -739,8 +755,6 @@ def corr_scoring(
|
|
739
755
|
instantiable.
|
740
756
|
interpolation_order : int
|
741
757
|
The order of interpolation to be used while rotating the template.
|
742
|
-
convolution_mode : str, optional
|
743
|
-
Mode to use for convolution, default is "full".
|
744
758
|
**kwargs :
|
745
759
|
Additional arguments to be passed to the function.
|
746
760
|
|
@@ -756,31 +770,23 @@ def corr_scoring(
|
|
756
770
|
:py:meth:`cam_setup`
|
757
771
|
:py:meth:`flcSphericalMask_setup`
|
758
772
|
"""
|
759
|
-
|
760
|
-
ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
|
761
|
-
inv_denominator_buffer, inv_denominator_pointer_shape, _ = inv_denominator
|
762
|
-
numerator2_buffer, numerator2_shape, _ = numerator2
|
763
|
-
filter_buffer, filter_shape, filter_dtype = template_filter
|
764
|
-
|
773
|
+
callback = callback_class
|
765
774
|
if callback_class is not None and isinstance(callback_class, type):
|
766
775
|
callback = callback_class(**callback_class_args)
|
767
|
-
elif not isinstance(callback_class, type):
|
768
|
-
callback = callback_class
|
769
776
|
|
770
|
-
|
771
|
-
template = backend.sharedarr_to_arr(template_shape, template_dtype
|
772
|
-
ft_target = backend.sharedarr_to_arr(
|
773
|
-
|
774
|
-
)
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
)
|
777
|
+
template_buffer, template_shape, template_dtype = template
|
778
|
+
template = backend.sharedarr_to_arr(template_buffer, template_shape, template_dtype)
|
779
|
+
ft_target = backend.sharedarr_to_arr(*ft_target)
|
780
|
+
inv_denominator = backend.sharedarr_to_arr(*inv_denominator)
|
781
|
+
numerator2 = backend.sharedarr_to_arr(*numerator2)
|
782
|
+
template_filter = backend.sharedarr_to_arr(*template_filter)
|
783
|
+
|
784
|
+
norm_template, template_mask, mask_sum = False, 1, 1
|
785
|
+
if "template_mask" in kwargs:
|
786
|
+
template_mask = backend.sharedarr_to_arr(
|
787
|
+
kwargs["template_mask"][0], template_shape, template_dtype
|
788
|
+
)
|
789
|
+
norm_template, mask_sum = True, backend.sum(template_mask)
|
784
790
|
|
785
791
|
arr = backend.preallocate_array(fast_shape, real_dtype)
|
786
792
|
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
@@ -799,17 +805,16 @@ def corr_scoring(
|
|
799
805
|
norm_denominator = (backend.sum(inv_denominator) != 1) & (
|
800
806
|
backend.size(inv_denominator) != 1
|
801
807
|
)
|
802
|
-
filter_template = backend.size(template_filter) != 0
|
803
808
|
|
809
|
+
norm_template = conditional_execute(normalize_under_mask, norm_template)
|
810
|
+
callback_func = conditional_execute(callback, callback_class is not None)
|
804
811
|
norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
|
805
812
|
norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
|
806
|
-
template_filter_func = conditional_execute(
|
807
|
-
|
808
|
-
|
809
|
-
fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
|
810
|
-
fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
|
813
|
+
template_filter_func = conditional_execute(
|
814
|
+
apply_filter, backend.size(template_filter) != 1
|
815
|
+
)
|
811
816
|
|
812
|
-
|
817
|
+
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
813
818
|
for index in range(rotations.shape[0]):
|
814
819
|
rotation = rotations[index]
|
815
820
|
backend.fill(arr, 0)
|
@@ -820,33 +825,24 @@ def corr_scoring(
|
|
820
825
|
use_geometric_center=False,
|
821
826
|
order=interpolation_order,
|
822
827
|
)
|
823
|
-
rotation_norm = template_sum / backend.sum(arr)
|
824
|
-
backend.multiply(arr, rotation_norm, out=arr)
|
825
828
|
|
826
|
-
|
827
|
-
template_filter_func(ft_temp, template_filter, out=ft_temp)
|
829
|
+
norm_template(arr[unpadded_slice], template_mask, mask_sum)
|
828
830
|
|
831
|
+
rfftn(arr, ft_temp)
|
832
|
+
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
829
833
|
backend.multiply(ft_target, ft_temp, out=ft_temp)
|
830
834
|
irfftn(ft_temp, arr)
|
831
835
|
|
832
836
|
norm_func_numerator(arr, numerator2, out=arr)
|
833
837
|
norm_func_denominator(arr, inv_denominator, out=arr)
|
834
838
|
|
835
|
-
|
836
|
-
arr
|
837
|
-
|
838
|
-
|
839
|
-
|
839
|
+
callback_func(
|
840
|
+
arr,
|
841
|
+
rotation_matrix=rotation,
|
842
|
+
rotation_index=index,
|
843
|
+
**callback_class_args,
|
840
844
|
)
|
841
845
|
|
842
|
-
if callback_class is not None:
|
843
|
-
callback(
|
844
|
-
score,
|
845
|
-
rotation_matrix=rotation,
|
846
|
-
rotation_index=index,
|
847
|
-
**callback_class_args,
|
848
|
-
)
|
849
|
-
|
850
846
|
return callback
|
851
847
|
|
852
848
|
|
@@ -894,32 +890,15 @@ def flc_scoring(
|
|
894
890
|
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
895
891
|
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
896
892
|
"""
|
897
|
-
|
898
|
-
template_mask_buffer, *_ = template_mask
|
899
|
-
filter_buffer, filter_shape, filter_dtype = template_filter
|
900
|
-
|
901
|
-
ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
|
902
|
-
ft_target2_buffer, *_ = ft_target2
|
903
|
-
|
893
|
+
callback = callback_class
|
904
894
|
if callback_class is not None and isinstance(callback_class, type):
|
905
895
|
callback = callback_class(**callback_class_args)
|
906
|
-
elif not isinstance(callback_class, type):
|
907
|
-
callback = callback_class
|
908
896
|
|
909
|
-
|
910
|
-
|
911
|
-
|
912
|
-
|
913
|
-
)
|
914
|
-
ft_target = backend.sharedarr_to_arr(
|
915
|
-
ft_target_shape, ft_target_dtype, ft_target_buffer
|
916
|
-
)
|
917
|
-
ft_target2 = backend.sharedarr_to_arr(
|
918
|
-
ft_target_shape, ft_target_dtype, ft_target2_buffer
|
919
|
-
)
|
920
|
-
template_filter = backend.sharedarr_to_arr(
|
921
|
-
filter_shape, filter_dtype, filter_buffer
|
922
|
-
)
|
897
|
+
template = backend.sharedarr_to_arr(*template)
|
898
|
+
template_mask = backend.sharedarr_to_arr(*template_mask)
|
899
|
+
ft_target = backend.sharedarr_to_arr(*ft_target)
|
900
|
+
ft_target2 = backend.sharedarr_to_arr(*ft_target2)
|
901
|
+
template_filter = backend.sharedarr_to_arr(*template_filter)
|
923
902
|
|
924
903
|
arr = backend.preallocate_array(fast_shape, real_dtype)
|
925
904
|
temp = backend.preallocate_array(fast_shape, real_dtype)
|
@@ -938,12 +917,10 @@ def flc_scoring(
|
|
938
917
|
temp_fft=ft_temp,
|
939
918
|
)
|
940
919
|
eps = backend.eps(real_dtype)
|
941
|
-
|
942
|
-
|
943
|
-
|
944
|
-
|
945
|
-
fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
|
946
|
-
fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
|
920
|
+
template_filter_func = conditional_execute(
|
921
|
+
apply_filter, backend.size(template_filter) != 1
|
922
|
+
)
|
923
|
+
callback_func = conditional_execute(callback, callback_class is not None)
|
947
924
|
|
948
925
|
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
949
926
|
for index in range(rotations.shape[0]):
|
@@ -962,7 +939,7 @@ def flc_scoring(
|
|
962
939
|
# Given the amount of FFTs, might aswell normalize properly
|
963
940
|
n_observations = backend.sum(temp)
|
964
941
|
|
965
|
-
|
942
|
+
normalize_under_mask(
|
966
943
|
template=arr[unpadded_slice],
|
967
944
|
mask=temp[unpadded_slice],
|
968
945
|
mask_intensity=n_observations,
|
@@ -985,7 +962,7 @@ def flc_scoring(
|
|
985
962
|
backend.multiply(temp, n_observations, out=temp)
|
986
963
|
|
987
964
|
rfftn(arr, ft_temp)
|
988
|
-
template_filter_func(ft_temp, template_filter
|
965
|
+
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
989
966
|
backend.multiply(ft_target, ft_temp, out=ft_temp)
|
990
967
|
irfftn(ft_temp, arr)
|
991
968
|
|
@@ -994,23 +971,161 @@ def flc_scoring(
|
|
994
971
|
backend.fill(temp2, 0)
|
995
972
|
temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
|
996
973
|
|
997
|
-
|
974
|
+
callback_func(
|
975
|
+
temp2,
|
976
|
+
rotation_matrix=rotation,
|
977
|
+
rotation_index=index,
|
978
|
+
**callback_class_args,
|
979
|
+
)
|
998
980
|
|
999
|
-
|
1000
|
-
temp2 = backend.roll(temp2, shift=fourier_shift, axis=axis)
|
981
|
+
return callback
|
1001
982
|
|
1002
|
-
|
1003
|
-
|
983
|
+
|
984
|
+
def flc_scoring2(
|
985
|
+
template: Tuple[type, Tuple[int], type],
|
986
|
+
template_mask: Tuple[type, Tuple[int], type],
|
987
|
+
ft_target: Tuple[type, Tuple[int], type],
|
988
|
+
ft_target2: Tuple[type, Tuple[int], type],
|
989
|
+
template_filter: Tuple[type, Tuple[int], type],
|
990
|
+
targetshape: Tuple[int],
|
991
|
+
templateshape: Tuple[int],
|
992
|
+
fast_shape: Tuple[int],
|
993
|
+
fast_ft_shape: Tuple[int],
|
994
|
+
rotations: NDArray,
|
995
|
+
real_dtype: type,
|
996
|
+
complex_dtype: type,
|
997
|
+
callback_class: CallbackClass,
|
998
|
+
callback_class_args: Dict,
|
999
|
+
interpolation_order: int,
|
1000
|
+
**kwargs,
|
1001
|
+
) -> CallbackClass:
|
1002
|
+
"""
|
1003
|
+
Computes a normalized cross-correlation score of a target f a template g
|
1004
|
+
and a mask m:
|
1005
|
+
|
1006
|
+
.. math::
|
1007
|
+
|
1008
|
+
\\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
|
1009
|
+
{N_m * \\sqrt{
|
1010
|
+
\\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
|
1011
|
+
}
|
1012
|
+
|
1013
|
+
Where:
|
1014
|
+
|
1015
|
+
.. math::
|
1016
|
+
|
1017
|
+
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
1018
|
+
|
1019
|
+
and Nm is the number of voxels within the template mask m.
|
1020
|
+
|
1021
|
+
References
|
1022
|
+
----------
|
1023
|
+
.. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
|
1024
|
+
Microsc. Microanal. 26, 2516 (2020)
|
1025
|
+
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
1026
|
+
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
1027
|
+
"""
|
1028
|
+
callback = callback_class
|
1029
|
+
if callback_class is not None and isinstance(callback_class, type):
|
1030
|
+
callback = callback_class(**callback_class_args)
|
1031
|
+
|
1032
|
+
# Retrieve objects from shared memory
|
1033
|
+
template = backend.sharedarr_to_arr(*template)
|
1034
|
+
template_mask = backend.sharedarr_to_arr(*template_mask)
|
1035
|
+
ft_target = backend.sharedarr_to_arr(*ft_target)
|
1036
|
+
ft_target2 = backend.sharedarr_to_arr(*ft_target2)
|
1037
|
+
template_filter = backend.sharedarr_to_arr(*template_filter)
|
1038
|
+
|
1039
|
+
arr = backend.preallocate_array(fast_shape, real_dtype)
|
1040
|
+
temp = backend.preallocate_array(fast_shape, real_dtype)
|
1041
|
+
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
1042
|
+
|
1043
|
+
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
1044
|
+
ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
1045
|
+
|
1046
|
+
eps = backend.eps(real_dtype)
|
1047
|
+
template_filter_func = conditional_execute(
|
1048
|
+
apply_filter, backend.size(template_filter) != 1
|
1049
|
+
)
|
1050
|
+
callback_func = conditional_execute(callback, callback_class is not None)
|
1051
|
+
|
1052
|
+
squeeze_axis = tuple(i for i, x in enumerate(template.shape) if x == 1)
|
1053
|
+
squeeze = tuple(
|
1054
|
+
slice(0, stop) if i not in squeeze_axis else 0
|
1055
|
+
for i, stop in enumerate(template.shape)
|
1056
|
+
)
|
1057
|
+
squeeze_fast = tuple(
|
1058
|
+
slice(0, stop) if i not in squeeze_axis else 0
|
1059
|
+
for i, stop in enumerate(fast_shape)
|
1060
|
+
)
|
1061
|
+
squeeze_fast_ft = tuple(
|
1062
|
+
slice(0, stop) if i not in squeeze_axis else 0
|
1063
|
+
for i, stop in enumerate(fast_ft_shape)
|
1064
|
+
)
|
1065
|
+
|
1066
|
+
rfftn, irfftn = backend.build_fft(
|
1067
|
+
fast_shape=temp[squeeze_fast].shape,
|
1068
|
+
fast_ft_shape=fast_ft_shape,
|
1069
|
+
real_dtype=real_dtype,
|
1070
|
+
complex_dtype=complex_dtype,
|
1071
|
+
fftargs=kwargs.get("fftargs", {}),
|
1072
|
+
inverse_fast_shape=fast_shape,
|
1073
|
+
temp_real=arr[squeeze_fast],
|
1074
|
+
temp_fft=ft_temp,
|
1075
|
+
)
|
1076
|
+
for index in range(rotations.shape[0]):
|
1077
|
+
rotation = rotations[index]
|
1078
|
+
backend.fill(arr, 0)
|
1079
|
+
backend.fill(temp, 0)
|
1080
|
+
backend.rotate_array(
|
1081
|
+
arr=template[squeeze],
|
1082
|
+
arr_mask=template_mask[squeeze],
|
1083
|
+
rotation_matrix=rotation,
|
1084
|
+
out=arr[squeeze],
|
1085
|
+
out_mask=temp[squeeze],
|
1086
|
+
use_geometric_center=False,
|
1087
|
+
order=interpolation_order,
|
1004
1088
|
)
|
1089
|
+
# Given the amount of FFTs, might aswell normalize properly
|
1090
|
+
n_observations = backend.sum(temp)
|
1005
1091
|
|
1006
|
-
|
1007
|
-
|
1008
|
-
|
1009
|
-
|
1010
|
-
|
1011
|
-
|
1012
|
-
|
1092
|
+
normalize_under_mask(
|
1093
|
+
template=arr[squeeze],
|
1094
|
+
mask=temp[squeeze],
|
1095
|
+
mask_intensity=n_observations,
|
1096
|
+
)
|
1097
|
+
rfftn(temp[squeeze_fast], ft_temp[squeeze_fast_ft])
|
1098
|
+
|
1099
|
+
backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1100
|
+
irfftn(ft_denom, temp)
|
1101
|
+
backend.divide(temp, n_observations, out=temp)
|
1102
|
+
backend.square(temp, out=temp)
|
1103
|
+
|
1104
|
+
backend.multiply(ft_target2, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1105
|
+
irfftn(ft_denom, temp2)
|
1106
|
+
backend.divide(temp2, n_observations, out=temp2)
|
1107
|
+
|
1108
|
+
backend.subtract(temp2, temp, out=temp)
|
1109
|
+
backend.maximum(temp, 0.0, out=temp)
|
1110
|
+
backend.sqrt(temp, out=temp)
|
1111
|
+
backend.multiply(temp, n_observations, out=temp)
|
1112
|
+
|
1113
|
+
rfftn(arr[squeeze_fast], ft_temp[squeeze_fast_ft])
|
1114
|
+
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
1115
|
+
backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1116
|
+
irfftn(ft_denom, arr)
|
1013
1117
|
|
1118
|
+
tol = tol = 1e3 * eps * backend.max(backend.abs(temp))
|
1119
|
+
nonzero_indices = temp > tol
|
1120
|
+
backend.fill(temp2, 0)
|
1121
|
+
temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
|
1122
|
+
|
1123
|
+
callback_func(
|
1124
|
+
temp2,
|
1125
|
+
rotation_matrix=rotation,
|
1126
|
+
rotation_index=index,
|
1127
|
+
**callback_class_args,
|
1128
|
+
)
|
1014
1129
|
return callback
|
1015
1130
|
|
1016
1131
|
|
@@ -1064,35 +1179,18 @@ def mcc_scoring(
|
|
1064
1179
|
--------
|
1065
1180
|
:py:class:`tme.matching_optimization.MaskedCrossCorrelation`
|
1066
1181
|
"""
|
1067
|
-
|
1068
|
-
ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
|
1069
|
-
ft_target2_buffer, ft_target_shape, ft_target_dtype = ft_target2
|
1070
|
-
template_mask_buffer, _, _ = template
|
1071
|
-
ft_target_mask_buffer, _, _ = ft_target
|
1072
|
-
filter_buffer, filter_shape, filter_dtype = template_filter
|
1073
|
-
|
1182
|
+
callback = callback_class
|
1074
1183
|
if callback_class is not None and isinstance(callback_class, type):
|
1075
1184
|
callback = callback_class(**callback_class_args)
|
1076
|
-
elif not isinstance(callback_class, type):
|
1077
|
-
callback = callback_class
|
1078
1185
|
|
1079
1186
|
# Retrieve objects from shared memory
|
1080
|
-
|
1081
|
-
|
1082
|
-
|
1083
|
-
)
|
1084
|
-
|
1085
|
-
|
1086
|
-
)
|
1087
|
-
template_mask = backend.sharedarr_to_arr(
|
1088
|
-
template_shape, template_dtype, template_mask_buffer
|
1089
|
-
)
|
1090
|
-
target_mask_ft = backend.sharedarr_to_arr(
|
1091
|
-
ft_target_shape, ft_target_dtype, ft_target_mask_buffer
|
1092
|
-
)
|
1093
|
-
template_filter = backend.sharedarr_to_arr(
|
1094
|
-
filter_shape, filter_dtype, filter_buffer
|
1095
|
-
)
|
1187
|
+
template_buffer, template_shape, template_dtype = template
|
1188
|
+
template = backend.sharedarr_to_arr(*template)
|
1189
|
+
target_ft = backend.sharedarr_to_arr(*ft_target)
|
1190
|
+
target_ft2 = backend.sharedarr_to_arr(*ft_target2)
|
1191
|
+
template_mask = backend.sharedarr_to_arr(*template_mask)
|
1192
|
+
target_mask_ft = backend.sharedarr_to_arr(*ft_target_mask)
|
1193
|
+
template_filter = backend.sharedarr_to_arr(*template_filter)
|
1096
1194
|
|
1097
1195
|
axes = tuple(range(template.ndim))
|
1098
1196
|
eps = backend.eps(real_dtype)
|
@@ -1117,14 +1215,10 @@ def mcc_scoring(
|
|
1117
1215
|
temp_fft=temp_ft,
|
1118
1216
|
)
|
1119
1217
|
|
1120
|
-
|
1121
|
-
|
1122
|
-
|
1123
|
-
axis = tuple(range(template.ndim))
|
1124
|
-
fourier_shift = callback_class_args.get(
|
1125
|
-
"fourier_shift", backend.zeros(template.ndim)
|
1218
|
+
template_filter_func = conditional_execute(
|
1219
|
+
apply_filter, backend.size(template_filter) != 1
|
1126
1220
|
)
|
1127
|
-
|
1221
|
+
callback_func = conditional_execute(callback, callback_class is not None)
|
1128
1222
|
|
1129
1223
|
# Calculate scores across all rotations
|
1130
1224
|
for index in range(rotations.shape[0]):
|
@@ -1146,7 +1240,7 @@ def mcc_scoring(
|
|
1146
1240
|
|
1147
1241
|
# template_rot_ft
|
1148
1242
|
rfftn(template_rot, temp_ft)
|
1149
|
-
template_filter_func(temp_ft, template_filter
|
1243
|
+
template_filter_func(ft_template=temp_ft, template_filter=template_filter)
|
1150
1244
|
irfftn(target_mask_ft * temp_ft, temp2)
|
1151
1245
|
irfftn(target_ft * temp_ft, numerator)
|
1152
1246
|
|
@@ -1202,25 +1296,69 @@ def mcc_scoring(
|
|
1202
1296
|
mask_overlap, axis=axes, keepdims=True
|
1203
1297
|
)
|
1204
1298
|
temp[mask_overlap < number_px_threshold] = 0.0
|
1205
|
-
convolution_mode = kwargs.get("convolution_mode", "full")
|
1206
1299
|
|
1207
|
-
|
1208
|
-
temp
|
1209
|
-
|
1210
|
-
|
1211
|
-
|
1300
|
+
callback_func(
|
1301
|
+
temp,
|
1302
|
+
rotation_matrix=rotation,
|
1303
|
+
rotation_index=index,
|
1304
|
+
**callback_class_args,
|
1212
1305
|
)
|
1213
|
-
if callback_class is not None:
|
1214
|
-
callback(
|
1215
|
-
score,
|
1216
|
-
rotation_matrix=rotation,
|
1217
|
-
rotation_index=index,
|
1218
|
-
**callback_class_args,
|
1219
|
-
)
|
1220
1306
|
|
1221
1307
|
return callback
|
1222
1308
|
|
1223
1309
|
|
1310
|
+
def _setup_template_filter(
|
1311
|
+
matching_data: MatchingData,
|
1312
|
+
rfftn: Callable,
|
1313
|
+
irfftn: Callable,
|
1314
|
+
fast_shape: Tuple[int],
|
1315
|
+
fast_ft_shape: Tuple[int],
|
1316
|
+
):
|
1317
|
+
filter_template = isinstance(matching_data.template_filter, Compose)
|
1318
|
+
filter_target = isinstance(matching_data.target_filter, Compose)
|
1319
|
+
|
1320
|
+
template_filter = backend.full(
|
1321
|
+
shape=(1,), fill_value=1, dtype=backend._default_dtype
|
1322
|
+
)
|
1323
|
+
|
1324
|
+
if not filter_template and not filter_target:
|
1325
|
+
return template_filter
|
1326
|
+
|
1327
|
+
target_temp = backend.astype(
|
1328
|
+
backend.topleft_pad(matching_data.target, fast_shape), backend._default_dtype
|
1329
|
+
)
|
1330
|
+
target_temp_ft = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
|
1331
|
+
rfftn(target_temp, target_temp_ft)
|
1332
|
+
|
1333
|
+
if isinstance(matching_data.template_filter, Compose):
|
1334
|
+
template_filter = matching_data.template_filter(
|
1335
|
+
shape=fast_shape,
|
1336
|
+
return_real_fourier=True,
|
1337
|
+
shape_is_real_fourier=False,
|
1338
|
+
data_rfft=target_temp_ft,
|
1339
|
+
)
|
1340
|
+
template_filter = template_filter["data"]
|
1341
|
+
template_filter[tuple(0 for _ in range(template_filter.ndim))] = 0
|
1342
|
+
|
1343
|
+
if isinstance(matching_data.target_filter, Compose):
|
1344
|
+
target_filter = matching_data.target_filter(
|
1345
|
+
shape=fast_shape,
|
1346
|
+
return_real_fourier=True,
|
1347
|
+
shape_is_real_fourier=False,
|
1348
|
+
data_rfft=target_temp_ft,
|
1349
|
+
weight_type=None,
|
1350
|
+
)
|
1351
|
+
target_filter = target_filter["data"]
|
1352
|
+
backend.multiply(target_temp_ft, target_filter, out=target_temp_ft)
|
1353
|
+
|
1354
|
+
irfftn(target_temp_ft, target_temp)
|
1355
|
+
matching_data._target = backend.topleft_pad(
|
1356
|
+
target_temp, matching_data.target.shape
|
1357
|
+
)
|
1358
|
+
|
1359
|
+
return template_filter
|
1360
|
+
|
1361
|
+
|
1224
1362
|
def device_memory_handler(func: Callable):
|
1225
1363
|
"""Decorator function providing SharedMemory Handler."""
|
1226
1364
|
|
@@ -1291,18 +1429,17 @@ def scan(
|
|
1291
1429
|
Tuple
|
1292
1430
|
The merged results from callback_class if provided otherwise None.
|
1293
1431
|
"""
|
1432
|
+
matching_data.to_backend()
|
1294
1433
|
shape_diff = backend.subtract(
|
1295
|
-
matching_data.
|
1434
|
+
matching_data._output_target_shape, matching_data._output_template_shape
|
1296
1435
|
)
|
1436
|
+
shape_diff = backend.multiply(shape_diff, ~matching_data._batch_mask)
|
1297
1437
|
if backend.sum(shape_diff < 0) and not pad_fourier:
|
1298
1438
|
warnings.warn(
|
1299
1439
|
"Target is larger than template and Fourier padding is turned off."
|
1300
1440
|
" This can lead to shifted results. You can swap template and target, or"
|
1301
1441
|
" zero-pad the target."
|
1302
1442
|
)
|
1303
|
-
|
1304
|
-
matching_data.to_backend()
|
1305
|
-
|
1306
1443
|
fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
|
1307
1444
|
pad_fourier=pad_fourier
|
1308
1445
|
)
|
@@ -1315,6 +1452,15 @@ def scan(
|
|
1315
1452
|
complex_dtype=matching_data._complex_dtype,
|
1316
1453
|
fftargs=fftargs,
|
1317
1454
|
)
|
1455
|
+
|
1456
|
+
# template_filter = _setup_template_filter(
|
1457
|
+
# matching_data=matching_data,
|
1458
|
+
# rfftn=rfftn,
|
1459
|
+
# irfftn=irfftn,
|
1460
|
+
# fast_shape=fast_shape,
|
1461
|
+
# fast_ft_shape=fast_ft_shape,
|
1462
|
+
# )
|
1463
|
+
|
1318
1464
|
setup = matching_setup(
|
1319
1465
|
rfftn=rfftn,
|
1320
1466
|
irfftn=irfftn,
|
@@ -1332,6 +1478,7 @@ def scan(
|
|
1332
1478
|
)
|
1333
1479
|
rfftn, irfftn = None, None
|
1334
1480
|
|
1481
|
+
|
1335
1482
|
template_filter, preprocessor = None, Preprocessor()
|
1336
1483
|
for method, parameters in matching_data.template_filter.items():
|
1337
1484
|
parameters["shape"] = fast_shape
|
@@ -1345,9 +1492,8 @@ def scan(
|
|
1345
1492
|
template_filter = backend.full(
|
1346
1493
|
shape=(1,), fill_value=1, dtype=backend._default_dtype
|
1347
1494
|
)
|
1348
|
-
else:
|
1349
|
-
template_filter = backend.to_backend_array(template_filter)
|
1350
1495
|
|
1496
|
+
template_filter = backend.to_backend_array(template_filter)
|
1351
1497
|
template_filter = backend.astype(template_filter, backend._default_dtype)
|
1352
1498
|
template_filter_buffer = backend.arr_to_sharedarr(
|
1353
1499
|
arr=template_filter,
|
@@ -1369,14 +1515,21 @@ def scan(
|
|
1369
1515
|
callback_class = setup.pop("callback_class", callback_class)
|
1370
1516
|
callback_class_args = setup.pop("callback_class_args", callback_class_args)
|
1371
1517
|
callback_classes = [callback_class for _ in range(n_callback_classes)]
|
1518
|
+
|
1519
|
+
convolution_mode = "same"
|
1520
|
+
if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
|
1521
|
+
convolution_mode = "valid"
|
1522
|
+
|
1523
|
+
|
1524
|
+
callback_class_args["fourier_shift"] = fourier_shift
|
1525
|
+
callback_class_args["convolution_mode"] = convolution_mode
|
1526
|
+
callback_class_args["targetshape"] = setup["targetshape"]
|
1527
|
+
callback_class_args["templateshape"] = setup["templateshape"]
|
1528
|
+
|
1372
1529
|
if callback_class == MaxScoreOverRotations:
|
1373
|
-
score_space_shape = backend.subtract(
|
1374
|
-
matching_data.target.shape,
|
1375
|
-
matching_data._target_pad,
|
1376
|
-
)
|
1377
1530
|
callback_classes = [
|
1378
1531
|
class_name(
|
1379
|
-
score_space_shape=
|
1532
|
+
score_space_shape=fast_shape,
|
1380
1533
|
score_space_dtype=matching_data._default_dtype,
|
1381
1534
|
shared_memory_handler=kwargs.get("shared_memory_handler", None),
|
1382
1535
|
rotation_space_dtype=backend._default_dtype_int,
|
@@ -1416,10 +1569,16 @@ def scan(
|
|
1416
1569
|
for index, rotation in enumerate(rotation_list)
|
1417
1570
|
)
|
1418
1571
|
|
1572
|
+
callbacks = callbacks[0:n_callback_classes]
|
1419
1573
|
callbacks = [
|
1420
|
-
tuple(callback
|
1421
|
-
|
1422
|
-
|
1574
|
+
tuple(callback._postprocess(
|
1575
|
+
fourier_shift = fourier_shift,
|
1576
|
+
convolution_mode = convolution_mode,
|
1577
|
+
targetshape = setup["targetshape"],
|
1578
|
+
templateshape = setup["templateshape"],
|
1579
|
+
shared_memory_handler=kwargs.get("shared_memory_handler", None)
|
1580
|
+
)) if hasattr(callback, "_postprocess") else tuple(callback)
|
1581
|
+
for callback in callbacks if callback is not None
|
1423
1582
|
]
|
1424
1583
|
backend.free_cache()
|
1425
1584
|
|
@@ -1530,11 +1689,13 @@ def scan_subsets(
|
|
1530
1689
|
matching_data._target, matching_data._template = None, None
|
1531
1690
|
matching_data._target_mask, matching_data._template_mask = None, None
|
1532
1691
|
|
1692
|
+
candidates = None
|
1533
1693
|
if callback_class is not None:
|
1534
1694
|
candidates = callback_class.merge(
|
1535
1695
|
results, **callback_class_args, inner_merge=False
|
1536
1696
|
)
|
1537
|
-
|
1697
|
+
|
1698
|
+
return candidates
|
1538
1699
|
|
1539
1700
|
|
1540
1701
|
MATCHING_EXHAUSTIVE_REGISTER = {
|
@@ -1544,6 +1705,7 @@ MATCHING_EXHAUSTIVE_REGISTER = {
|
|
1544
1705
|
"CAM": (cam_setup, corr_scoring),
|
1545
1706
|
"FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
|
1546
1707
|
"FLC": (flc_setup, flc_scoring),
|
1708
|
+
"FLC2": (flc_setup, flc_scoring2),
|
1547
1709
|
"MCC": (mcc_setup, mcc_scoring),
|
1548
1710
|
}
|
1549
1711
|
|