pytme 0.1.9__cp311-cp311-macosx_14_0_arm64.whl → 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.1.9.data → pytme-0.2.0b0.data}/scripts/match_template.py +148 -126
- pytme-0.2.0b0.data/scripts/postprocess.py +570 -0
- {pytme-0.1.9.data → pytme-0.2.0b0.data}/scripts/preprocessor_gui.py +244 -60
- {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/METADATA +3 -1
- pytme-0.2.0b0.dist-info/RECORD +66 -0
- {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/WHEEL +1 -1
- scripts/extract_candidates.py +218 -0
- scripts/match_template.py +148 -126
- scripts/match_template_filters.py +852 -0
- scripts/postprocess.py +380 -435
- scripts/preprocessor_gui.py +244 -60
- scripts/refine_matches.py +218 -0
- tme/__init__.py +2 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +545 -78
- tme/backends/cupy_backend.py +80 -15
- tme/backends/npfftw_backend.py +33 -2
- tme/backends/pytorch_backend.py +15 -7
- tme/density.py +156 -63
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_constrained.py +195 -0
- tme/matching_data.py +74 -33
- tme/matching_exhaustive.py +351 -208
- tme/matching_memory.py +1 -0
- tme/matching_optimization.py +728 -651
- tme/matching_utils.py +152 -8
- tme/orientations.py +561 -0
- tme/preprocessor.py +21 -18
- tme/structure.py +2 -37
- pytme-0.1.9.data/scripts/postprocess.py +0 -625
- pytme-0.1.9.dist-info/RECORD +0 -61
- {pytme-0.1.9.data → pytme-0.2.0b0.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.1.9.data → pytme-0.2.0b0.data}/scripts/preprocess.py +0 -0
- {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/LICENSE +0 -0
- {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/entry_points.txt +0 -0
- {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/top_level.txt +0 -0
tme/matching_exhaustive.py
CHANGED
@@ -25,6 +25,7 @@ from .matching_utils import (
|
|
25
25
|
conditional_execute,
|
26
26
|
)
|
27
27
|
from .matching_memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
|
28
|
+
# from .preprocessing import Compose
|
28
29
|
from .preprocessor import Preprocessor
|
29
30
|
from .matching_data import MatchingData
|
30
31
|
from .backends import backend
|
@@ -43,6 +44,58 @@ def _run_inner(backend_name, backend_args, **kwargs):
|
|
43
44
|
return scan(**kwargs)
|
44
45
|
|
45
46
|
|
47
|
+
def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
|
48
|
+
"""
|
49
|
+
Standardizes the values in in template by subtracting the mean and dividing by the
|
50
|
+
standard deviation based on the elements in mask. Subsequently, the template is
|
51
|
+
multiplied by the mask.
|
52
|
+
|
53
|
+
Parameters
|
54
|
+
----------
|
55
|
+
template : NDArray
|
56
|
+
The data array to be normalized. This array is modified in-place.
|
57
|
+
mask : NDArray
|
58
|
+
A boolean array of the same shape as `template`. True values indicate the positions in `template`
|
59
|
+
to consider for normalization.
|
60
|
+
mask_intensity : float
|
61
|
+
Mask intensity used to compute expectations.
|
62
|
+
|
63
|
+
References
|
64
|
+
----------
|
65
|
+
.. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
66
|
+
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
67
|
+
.. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
|
68
|
+
R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
|
69
|
+
13375 (2023)
|
70
|
+
|
71
|
+
Returns
|
72
|
+
-------
|
73
|
+
None
|
74
|
+
This function modifies `template` in-place and does not return any value.
|
75
|
+
"""
|
76
|
+
masked_mean = backend.sum(backend.multiply(template, mask))
|
77
|
+
masked_mean = backend.divide(masked_mean, mask_intensity)
|
78
|
+
masked_std = backend.sum(backend.multiply(backend.square(template), mask))
|
79
|
+
masked_std = backend.subtract(
|
80
|
+
masked_std / mask_intensity, backend.square(masked_mean)
|
81
|
+
)
|
82
|
+
masked_std = backend.sqrt(backend.maximum(masked_std, 0))
|
83
|
+
|
84
|
+
backend.subtract(template, masked_mean, out=template)
|
85
|
+
backend.divide(template, masked_std, out=template)
|
86
|
+
backend.multiply(template, mask, out=template)
|
87
|
+
return None
|
88
|
+
|
89
|
+
|
90
|
+
def apply_filter(ft_template, template_filter):
|
91
|
+
# This is an approximation to applying the mask, irfftn, normalize, rfftn
|
92
|
+
std_before = backend.std(ft_template)
|
93
|
+
backend.multiply(ft_template, template_filter, out=ft_template)
|
94
|
+
backend.multiply(
|
95
|
+
ft_template, std_before / backend.std(ft_template), out=ft_template
|
96
|
+
)
|
97
|
+
|
98
|
+
|
46
99
|
def cc_setup(
|
47
100
|
rfftn: Callable,
|
48
101
|
irfftn: Callable,
|
@@ -212,13 +265,12 @@ def corr_setup(
|
|
212
265
|
template_mean = backend.sum(backend.multiply(template, template_mask))
|
213
266
|
template_mean = backend.divide(template_mean, n_observations)
|
214
267
|
template_ssd = backend.sum(
|
215
|
-
backend.square(
|
216
|
-
backend.multiply(template, template_mean),
|
217
|
-
|
218
|
-
))
|
268
|
+
backend.square(
|
269
|
+
backend.multiply(backend.multiply(template, template_mean), template_mask)
|
270
|
+
)
|
219
271
|
)
|
220
272
|
template_volume = np.prod(template.shape)
|
221
|
-
backend.multiply(template, template_mask, out
|
273
|
+
backend.multiply(template, template_mask, out=template)
|
222
274
|
|
223
275
|
# Final numerator is score - numerator2
|
224
276
|
numerator2 = backend.multiply(target_window_sum, template_mean)
|
@@ -314,49 +366,6 @@ def cam_setup(**kwargs):
|
|
314
366
|
return corr_setup(**kwargs)
|
315
367
|
|
316
368
|
|
317
|
-
def _normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
|
318
|
-
"""
|
319
|
-
Standardizes the values in in template by subtracting the mean and dividing by the
|
320
|
-
standard deviation based on the elements in mask. Subsequently, the template is
|
321
|
-
multiplied by the mask.
|
322
|
-
|
323
|
-
Parameters
|
324
|
-
----------
|
325
|
-
template : NDArray
|
326
|
-
The data array to be normalized. This array is modified in-place.
|
327
|
-
mask : NDArray
|
328
|
-
A boolean array of the same shape as `template`. True values indicate the positions in `template`
|
329
|
-
to consider for normalization.
|
330
|
-
mask_intensity : float
|
331
|
-
Mask intensity used to compute expectations.
|
332
|
-
|
333
|
-
References
|
334
|
-
----------
|
335
|
-
.. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
336
|
-
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
337
|
-
.. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
|
338
|
-
R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
|
339
|
-
13375 (2023)
|
340
|
-
|
341
|
-
Returns
|
342
|
-
-------
|
343
|
-
None
|
344
|
-
This function modifies `template` in-place and does not return any value.
|
345
|
-
"""
|
346
|
-
masked_mean = backend.sum(backend.multiply(template, mask))
|
347
|
-
masked_mean = backend.divide(masked_mean, mask_intensity)
|
348
|
-
masked_std = backend.sum(backend.multiply(backend.square(template), mask))
|
349
|
-
masked_std = backend.subtract(
|
350
|
-
masked_std / mask_intensity, backend.square(masked_mean)
|
351
|
-
)
|
352
|
-
masked_std = backend.sqrt(backend.maximum(masked_std, 0))
|
353
|
-
|
354
|
-
backend.subtract(template, masked_mean, out=template)
|
355
|
-
backend.divide(template, masked_std, out=template)
|
356
|
-
backend.multiply(template, mask, out=template)
|
357
|
-
return None
|
358
|
-
|
359
|
-
|
360
369
|
def flc_setup(
|
361
370
|
rfftn: Callable,
|
362
371
|
irfftn: Callable,
|
@@ -419,7 +428,7 @@ def flc_setup(
|
|
419
428
|
arr=ft_target2, shared_memory_handler=shared_memory_handler
|
420
429
|
)
|
421
430
|
|
422
|
-
|
431
|
+
normalize_under_mask(
|
423
432
|
template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
|
424
433
|
)
|
425
434
|
|
@@ -541,7 +550,7 @@ def flcSphericalMask_setup(
|
|
541
550
|
backend.fill(temp2, 0)
|
542
551
|
temp2[nonzero_indices] = 1 / temp[nonzero_indices]
|
543
552
|
|
544
|
-
|
553
|
+
normalize_under_mask(
|
545
554
|
template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
|
546
555
|
)
|
547
556
|
|
@@ -728,10 +737,6 @@ def corr_scoring(
|
|
728
737
|
datatype.
|
729
738
|
numerator2 : Tuple[type, Tuple[int], type]
|
730
739
|
Tuple containing a pointer to the numerator2 data, its shape, and its datatype.
|
731
|
-
targetshape : Tuple[int]
|
732
|
-
The shape of the target.
|
733
|
-
templateshape : Tuple[int]
|
734
|
-
The shape of the template.
|
735
740
|
fast_shape : Tuple[int]
|
736
741
|
The shape for fast Fourier transform.
|
737
742
|
fast_ft_shape : Tuple[int]
|
@@ -750,8 +755,6 @@ def corr_scoring(
|
|
750
755
|
instantiable.
|
751
756
|
interpolation_order : int
|
752
757
|
The order of interpolation to be used while rotating the template.
|
753
|
-
convolution_mode : str, optional
|
754
|
-
Mode to use for convolution, default is "full".
|
755
758
|
**kwargs :
|
756
759
|
Additional arguments to be passed to the function.
|
757
760
|
|
@@ -767,37 +770,23 @@ def corr_scoring(
|
|
767
770
|
:py:meth:`cam_setup`
|
768
771
|
:py:meth:`flcSphericalMask_setup`
|
769
772
|
"""
|
770
|
-
|
771
|
-
ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
|
772
|
-
inv_denominator_buffer, inv_denominator_pointer_shape, _ = inv_denominator
|
773
|
-
numerator2_buffer, numerator2_shape, _ = numerator2
|
774
|
-
filter_buffer, filter_shape, filter_dtype = template_filter
|
775
|
-
|
773
|
+
callback = callback_class
|
776
774
|
if callback_class is not None and isinstance(callback_class, type):
|
777
775
|
callback = callback_class(**callback_class_args)
|
778
|
-
elif not isinstance(callback_class, type):
|
779
|
-
callback = callback_class
|
780
776
|
|
781
|
-
|
782
|
-
template = backend.sharedarr_to_arr(template_shape, template_dtype
|
783
|
-
ft_target = backend.sharedarr_to_arr(
|
784
|
-
|
785
|
-
)
|
786
|
-
|
787
|
-
inv_denominator_pointer_shape, template_dtype, inv_denominator_buffer
|
788
|
-
)
|
789
|
-
numerator2 = backend.sharedarr_to_arr(
|
790
|
-
numerator2_shape, template_dtype, numerator2_buffer
|
791
|
-
)
|
792
|
-
template_filter = backend.sharedarr_to_arr(
|
793
|
-
filter_shape, filter_dtype, filter_buffer
|
794
|
-
)
|
777
|
+
template_buffer, template_shape, template_dtype = template
|
778
|
+
template = backend.sharedarr_to_arr(template_buffer, template_shape, template_dtype)
|
779
|
+
ft_target = backend.sharedarr_to_arr(*ft_target)
|
780
|
+
inv_denominator = backend.sharedarr_to_arr(*inv_denominator)
|
781
|
+
numerator2 = backend.sharedarr_to_arr(*numerator2)
|
782
|
+
template_filter = backend.sharedarr_to_arr(*template_filter)
|
795
783
|
|
796
784
|
norm_template, template_mask, mask_sum = False, 1, 1
|
797
785
|
if "template_mask" in kwargs:
|
798
|
-
template_mask = backend.sharedarr_to_arr(
|
786
|
+
template_mask = backend.sharedarr_to_arr(
|
787
|
+
kwargs["template_mask"][0], template_shape, template_dtype
|
788
|
+
)
|
799
789
|
norm_template, mask_sum = True, backend.sum(template_mask)
|
800
|
-
norm_template = conditional_execute(_normalize_under_mask, norm_template)
|
801
790
|
|
802
791
|
arr = backend.preallocate_array(fast_shape, real_dtype)
|
803
792
|
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
@@ -816,17 +805,15 @@ def corr_scoring(
|
|
816
805
|
norm_denominator = (backend.sum(inv_denominator) != 1) & (
|
817
806
|
backend.size(inv_denominator) != 1
|
818
807
|
)
|
819
|
-
filter_template = backend.size(template_filter) != 0
|
820
808
|
|
809
|
+
norm_template = conditional_execute(normalize_under_mask, norm_template)
|
810
|
+
callback_func = conditional_execute(callback, callback_class is not None)
|
821
811
|
norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
|
822
812
|
norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
|
823
|
-
template_filter_func = conditional_execute(
|
824
|
-
|
825
|
-
|
826
|
-
fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
|
827
|
-
fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
|
813
|
+
template_filter_func = conditional_execute(
|
814
|
+
apply_filter, backend.size(template_filter) != 1
|
815
|
+
)
|
828
816
|
|
829
|
-
template_sum = backend.sum(template)
|
830
817
|
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
831
818
|
for index in range(rotations.shape[0]):
|
832
819
|
rotation = rotations[index]
|
@@ -838,34 +825,24 @@ def corr_scoring(
|
|
838
825
|
use_geometric_center=False,
|
839
826
|
order=interpolation_order,
|
840
827
|
)
|
841
|
-
|
842
|
-
backend.multiply(arr, rotation_norm, out=arr)
|
828
|
+
|
843
829
|
norm_template(arr[unpadded_slice], template_mask, mask_sum)
|
844
830
|
|
845
831
|
rfftn(arr, ft_temp)
|
846
|
-
template_filter_func(ft_temp, template_filter
|
847
|
-
|
832
|
+
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
848
833
|
backend.multiply(ft_target, ft_temp, out=ft_temp)
|
849
834
|
irfftn(ft_temp, arr)
|
850
835
|
|
851
836
|
norm_func_numerator(arr, numerator2, out=arr)
|
852
837
|
norm_func_denominator(arr, inv_denominator, out=arr)
|
853
838
|
|
854
|
-
|
855
|
-
arr
|
856
|
-
|
857
|
-
|
858
|
-
|
839
|
+
callback_func(
|
840
|
+
arr,
|
841
|
+
rotation_matrix=rotation,
|
842
|
+
rotation_index=index,
|
843
|
+
**callback_class_args,
|
859
844
|
)
|
860
845
|
|
861
|
-
if callback_class is not None:
|
862
|
-
callback(
|
863
|
-
score,
|
864
|
-
rotation_matrix=rotation,
|
865
|
-
rotation_index=index,
|
866
|
-
**callback_class_args,
|
867
|
-
)
|
868
|
-
|
869
846
|
return callback
|
870
847
|
|
871
848
|
|
@@ -913,32 +890,15 @@ def flc_scoring(
|
|
913
890
|
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
914
891
|
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
915
892
|
"""
|
916
|
-
|
917
|
-
template_mask_buffer, *_ = template_mask
|
918
|
-
filter_buffer, filter_shape, filter_dtype = template_filter
|
919
|
-
|
920
|
-
ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
|
921
|
-
ft_target2_buffer, *_ = ft_target2
|
922
|
-
|
893
|
+
callback = callback_class
|
923
894
|
if callback_class is not None and isinstance(callback_class, type):
|
924
895
|
callback = callback_class(**callback_class_args)
|
925
|
-
elif not isinstance(callback_class, type):
|
926
|
-
callback = callback_class
|
927
896
|
|
928
|
-
|
929
|
-
|
930
|
-
|
931
|
-
|
932
|
-
)
|
933
|
-
ft_target = backend.sharedarr_to_arr(
|
934
|
-
ft_target_shape, ft_target_dtype, ft_target_buffer
|
935
|
-
)
|
936
|
-
ft_target2 = backend.sharedarr_to_arr(
|
937
|
-
ft_target_shape, ft_target_dtype, ft_target2_buffer
|
938
|
-
)
|
939
|
-
template_filter = backend.sharedarr_to_arr(
|
940
|
-
filter_shape, filter_dtype, filter_buffer
|
941
|
-
)
|
897
|
+
template = backend.sharedarr_to_arr(*template)
|
898
|
+
template_mask = backend.sharedarr_to_arr(*template_mask)
|
899
|
+
ft_target = backend.sharedarr_to_arr(*ft_target)
|
900
|
+
ft_target2 = backend.sharedarr_to_arr(*ft_target2)
|
901
|
+
template_filter = backend.sharedarr_to_arr(*template_filter)
|
942
902
|
|
943
903
|
arr = backend.preallocate_array(fast_shape, real_dtype)
|
944
904
|
temp = backend.preallocate_array(fast_shape, real_dtype)
|
@@ -957,12 +917,10 @@ def flc_scoring(
|
|
957
917
|
temp_fft=ft_temp,
|
958
918
|
)
|
959
919
|
eps = backend.eps(real_dtype)
|
960
|
-
|
961
|
-
|
962
|
-
|
963
|
-
|
964
|
-
fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
|
965
|
-
fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
|
920
|
+
template_filter_func = conditional_execute(
|
921
|
+
apply_filter, backend.size(template_filter) != 1
|
922
|
+
)
|
923
|
+
callback_func = conditional_execute(callback, callback_class is not None)
|
966
924
|
|
967
925
|
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
968
926
|
for index in range(rotations.shape[0]):
|
@@ -981,7 +939,7 @@ def flc_scoring(
|
|
981
939
|
# Given the amount of FFTs, might aswell normalize properly
|
982
940
|
n_observations = backend.sum(temp)
|
983
941
|
|
984
|
-
|
942
|
+
normalize_under_mask(
|
985
943
|
template=arr[unpadded_slice],
|
986
944
|
mask=temp[unpadded_slice],
|
987
945
|
mask_intensity=n_observations,
|
@@ -1004,7 +962,7 @@ def flc_scoring(
|
|
1004
962
|
backend.multiply(temp, n_observations, out=temp)
|
1005
963
|
|
1006
964
|
rfftn(arr, ft_temp)
|
1007
|
-
template_filter_func(ft_temp, template_filter
|
965
|
+
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
1008
966
|
backend.multiply(ft_target, ft_temp, out=ft_temp)
|
1009
967
|
irfftn(ft_temp, arr)
|
1010
968
|
|
@@ -1013,23 +971,161 @@ def flc_scoring(
|
|
1013
971
|
backend.fill(temp2, 0)
|
1014
972
|
temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
|
1015
973
|
|
1016
|
-
|
974
|
+
callback_func(
|
975
|
+
temp2,
|
976
|
+
rotation_matrix=rotation,
|
977
|
+
rotation_index=index,
|
978
|
+
**callback_class_args,
|
979
|
+
)
|
980
|
+
|
981
|
+
return callback
|
982
|
+
|
983
|
+
|
984
|
+
def flc_scoring2(
|
985
|
+
template: Tuple[type, Tuple[int], type],
|
986
|
+
template_mask: Tuple[type, Tuple[int], type],
|
987
|
+
ft_target: Tuple[type, Tuple[int], type],
|
988
|
+
ft_target2: Tuple[type, Tuple[int], type],
|
989
|
+
template_filter: Tuple[type, Tuple[int], type],
|
990
|
+
targetshape: Tuple[int],
|
991
|
+
templateshape: Tuple[int],
|
992
|
+
fast_shape: Tuple[int],
|
993
|
+
fast_ft_shape: Tuple[int],
|
994
|
+
rotations: NDArray,
|
995
|
+
real_dtype: type,
|
996
|
+
complex_dtype: type,
|
997
|
+
callback_class: CallbackClass,
|
998
|
+
callback_class_args: Dict,
|
999
|
+
interpolation_order: int,
|
1000
|
+
**kwargs,
|
1001
|
+
) -> CallbackClass:
|
1002
|
+
"""
|
1003
|
+
Computes a normalized cross-correlation score of a target f a template g
|
1004
|
+
and a mask m:
|
1005
|
+
|
1006
|
+
.. math::
|
1007
|
+
|
1008
|
+
\\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
|
1009
|
+
{N_m * \\sqrt{
|
1010
|
+
\\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
|
1011
|
+
}
|
1017
1012
|
|
1018
|
-
|
1019
|
-
|
1013
|
+
Where:
|
1014
|
+
|
1015
|
+
.. math::
|
1020
1016
|
|
1021
|
-
|
1022
|
-
|
1017
|
+
CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
|
1018
|
+
|
1019
|
+
and Nm is the number of voxels within the template mask m.
|
1020
|
+
|
1021
|
+
References
|
1022
|
+
----------
|
1023
|
+
.. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
|
1024
|
+
Microsc. Microanal. 26, 2516 (2020)
|
1025
|
+
.. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
|
1026
|
+
and F. Förster, J. Struct. Biol. 178, 177 (2012).
|
1027
|
+
"""
|
1028
|
+
callback = callback_class
|
1029
|
+
if callback_class is not None and isinstance(callback_class, type):
|
1030
|
+
callback = callback_class(**callback_class_args)
|
1031
|
+
|
1032
|
+
# Retrieve objects from shared memory
|
1033
|
+
template = backend.sharedarr_to_arr(*template)
|
1034
|
+
template_mask = backend.sharedarr_to_arr(*template_mask)
|
1035
|
+
ft_target = backend.sharedarr_to_arr(*ft_target)
|
1036
|
+
ft_target2 = backend.sharedarr_to_arr(*ft_target2)
|
1037
|
+
template_filter = backend.sharedarr_to_arr(*template_filter)
|
1038
|
+
|
1039
|
+
arr = backend.preallocate_array(fast_shape, real_dtype)
|
1040
|
+
temp = backend.preallocate_array(fast_shape, real_dtype)
|
1041
|
+
temp2 = backend.preallocate_array(fast_shape, real_dtype)
|
1042
|
+
|
1043
|
+
ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
1044
|
+
ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
|
1045
|
+
|
1046
|
+
eps = backend.eps(real_dtype)
|
1047
|
+
template_filter_func = conditional_execute(
|
1048
|
+
apply_filter, backend.size(template_filter) != 1
|
1049
|
+
)
|
1050
|
+
callback_func = conditional_execute(callback, callback_class is not None)
|
1051
|
+
|
1052
|
+
squeeze_axis = tuple(i for i, x in enumerate(template.shape) if x == 1)
|
1053
|
+
squeeze = tuple(
|
1054
|
+
slice(0, stop) if i not in squeeze_axis else 0
|
1055
|
+
for i, stop in enumerate(template.shape)
|
1056
|
+
)
|
1057
|
+
squeeze_fast = tuple(
|
1058
|
+
slice(0, stop) if i not in squeeze_axis else 0
|
1059
|
+
for i, stop in enumerate(fast_shape)
|
1060
|
+
)
|
1061
|
+
squeeze_fast_ft = tuple(
|
1062
|
+
slice(0, stop) if i not in squeeze_axis else 0
|
1063
|
+
for i, stop in enumerate(fast_ft_shape)
|
1064
|
+
)
|
1065
|
+
|
1066
|
+
rfftn, irfftn = backend.build_fft(
|
1067
|
+
fast_shape=temp[squeeze_fast].shape,
|
1068
|
+
fast_ft_shape=fast_ft_shape,
|
1069
|
+
real_dtype=real_dtype,
|
1070
|
+
complex_dtype=complex_dtype,
|
1071
|
+
fftargs=kwargs.get("fftargs", {}),
|
1072
|
+
inverse_fast_shape=fast_shape,
|
1073
|
+
temp_real=arr[squeeze_fast],
|
1074
|
+
temp_fft=ft_temp,
|
1075
|
+
)
|
1076
|
+
for index in range(rotations.shape[0]):
|
1077
|
+
rotation = rotations[index]
|
1078
|
+
backend.fill(arr, 0)
|
1079
|
+
backend.fill(temp, 0)
|
1080
|
+
backend.rotate_array(
|
1081
|
+
arr=template[squeeze],
|
1082
|
+
arr_mask=template_mask[squeeze],
|
1083
|
+
rotation_matrix=rotation,
|
1084
|
+
out=arr[squeeze],
|
1085
|
+
out_mask=temp[squeeze],
|
1086
|
+
use_geometric_center=False,
|
1087
|
+
order=interpolation_order,
|
1023
1088
|
)
|
1089
|
+
# Given the amount of FFTs, might aswell normalize properly
|
1090
|
+
n_observations = backend.sum(temp)
|
1024
1091
|
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1028
|
-
|
1029
|
-
|
1030
|
-
|
1031
|
-
)
|
1092
|
+
normalize_under_mask(
|
1093
|
+
template=arr[squeeze],
|
1094
|
+
mask=temp[squeeze],
|
1095
|
+
mask_intensity=n_observations,
|
1096
|
+
)
|
1097
|
+
rfftn(temp[squeeze_fast], ft_temp[squeeze_fast_ft])
|
1032
1098
|
|
1099
|
+
backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1100
|
+
irfftn(ft_denom, temp)
|
1101
|
+
backend.divide(temp, n_observations, out=temp)
|
1102
|
+
backend.square(temp, out=temp)
|
1103
|
+
|
1104
|
+
backend.multiply(ft_target2, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1105
|
+
irfftn(ft_denom, temp2)
|
1106
|
+
backend.divide(temp2, n_observations, out=temp2)
|
1107
|
+
|
1108
|
+
backend.subtract(temp2, temp, out=temp)
|
1109
|
+
backend.maximum(temp, 0.0, out=temp)
|
1110
|
+
backend.sqrt(temp, out=temp)
|
1111
|
+
backend.multiply(temp, n_observations, out=temp)
|
1112
|
+
|
1113
|
+
rfftn(arr[squeeze_fast], ft_temp[squeeze_fast_ft])
|
1114
|
+
template_filter_func(ft_template=ft_temp, template_filter=template_filter)
|
1115
|
+
backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
|
1116
|
+
irfftn(ft_denom, arr)
|
1117
|
+
|
1118
|
+
tol = tol = 1e3 * eps * backend.max(backend.abs(temp))
|
1119
|
+
nonzero_indices = temp > tol
|
1120
|
+
backend.fill(temp2, 0)
|
1121
|
+
temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
|
1122
|
+
|
1123
|
+
callback_func(
|
1124
|
+
temp2,
|
1125
|
+
rotation_matrix=rotation,
|
1126
|
+
rotation_index=index,
|
1127
|
+
**callback_class_args,
|
1128
|
+
)
|
1033
1129
|
return callback
|
1034
1130
|
|
1035
1131
|
|
@@ -1083,35 +1179,18 @@ def mcc_scoring(
|
|
1083
1179
|
--------
|
1084
1180
|
:py:class:`tme.matching_optimization.MaskedCrossCorrelation`
|
1085
1181
|
"""
|
1086
|
-
|
1087
|
-
ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
|
1088
|
-
ft_target2_buffer, ft_target_shape, ft_target_dtype = ft_target2
|
1089
|
-
template_mask_buffer, _, _ = template
|
1090
|
-
ft_target_mask_buffer, _, _ = ft_target
|
1091
|
-
filter_buffer, filter_shape, filter_dtype = template_filter
|
1092
|
-
|
1182
|
+
callback = callback_class
|
1093
1183
|
if callback_class is not None and isinstance(callback_class, type):
|
1094
1184
|
callback = callback_class(**callback_class_args)
|
1095
|
-
elif not isinstance(callback_class, type):
|
1096
|
-
callback = callback_class
|
1097
1185
|
|
1098
1186
|
# Retrieve objects from shared memory
|
1099
|
-
|
1100
|
-
|
1101
|
-
|
1102
|
-
)
|
1103
|
-
|
1104
|
-
|
1105
|
-
)
|
1106
|
-
template_mask = backend.sharedarr_to_arr(
|
1107
|
-
template_shape, template_dtype, template_mask_buffer
|
1108
|
-
)
|
1109
|
-
target_mask_ft = backend.sharedarr_to_arr(
|
1110
|
-
ft_target_shape, ft_target_dtype, ft_target_mask_buffer
|
1111
|
-
)
|
1112
|
-
template_filter = backend.sharedarr_to_arr(
|
1113
|
-
filter_shape, filter_dtype, filter_buffer
|
1114
|
-
)
|
1187
|
+
template_buffer, template_shape, template_dtype = template
|
1188
|
+
template = backend.sharedarr_to_arr(*template)
|
1189
|
+
target_ft = backend.sharedarr_to_arr(*ft_target)
|
1190
|
+
target_ft2 = backend.sharedarr_to_arr(*ft_target2)
|
1191
|
+
template_mask = backend.sharedarr_to_arr(*template_mask)
|
1192
|
+
target_mask_ft = backend.sharedarr_to_arr(*ft_target_mask)
|
1193
|
+
template_filter = backend.sharedarr_to_arr(*template_filter)
|
1115
1194
|
|
1116
1195
|
axes = tuple(range(template.ndim))
|
1117
1196
|
eps = backend.eps(real_dtype)
|
@@ -1136,14 +1215,10 @@ def mcc_scoring(
|
|
1136
1215
|
temp_fft=temp_ft,
|
1137
1216
|
)
|
1138
1217
|
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
axis = tuple(range(template.ndim))
|
1143
|
-
fourier_shift = callback_class_args.get(
|
1144
|
-
"fourier_shift", backend.zeros(template.ndim)
|
1218
|
+
template_filter_func = conditional_execute(
|
1219
|
+
apply_filter, backend.size(template_filter) != 1
|
1145
1220
|
)
|
1146
|
-
|
1221
|
+
callback_func = conditional_execute(callback, callback_class is not None)
|
1147
1222
|
|
1148
1223
|
# Calculate scores across all rotations
|
1149
1224
|
for index in range(rotations.shape[0]):
|
@@ -1165,7 +1240,7 @@ def mcc_scoring(
|
|
1165
1240
|
|
1166
1241
|
# template_rot_ft
|
1167
1242
|
rfftn(template_rot, temp_ft)
|
1168
|
-
template_filter_func(temp_ft, template_filter
|
1243
|
+
template_filter_func(ft_template=temp_ft, template_filter=template_filter)
|
1169
1244
|
irfftn(target_mask_ft * temp_ft, temp2)
|
1170
1245
|
irfftn(target_ft * temp_ft, numerator)
|
1171
1246
|
|
@@ -1221,25 +1296,69 @@ def mcc_scoring(
|
|
1221
1296
|
mask_overlap, axis=axes, keepdims=True
|
1222
1297
|
)
|
1223
1298
|
temp[mask_overlap < number_px_threshold] = 0.0
|
1224
|
-
convolution_mode = kwargs.get("convolution_mode", "full")
|
1225
|
-
|
1226
|
-
if fourier_shift_scores:
|
1227
|
-
temp = backend.roll(temp, shift=fourier_shift, axis=axis)
|
1228
1299
|
|
1229
|
-
|
1230
|
-
temp,
|
1300
|
+
callback_func(
|
1301
|
+
temp,
|
1302
|
+
rotation_matrix=rotation,
|
1303
|
+
rotation_index=index,
|
1304
|
+
**callback_class_args,
|
1231
1305
|
)
|
1232
|
-
if callback_class is not None:
|
1233
|
-
callback(
|
1234
|
-
score,
|
1235
|
-
rotation_matrix=rotation,
|
1236
|
-
rotation_index=index,
|
1237
|
-
**callback_class_args,
|
1238
|
-
)
|
1239
1306
|
|
1240
1307
|
return callback
|
1241
1308
|
|
1242
1309
|
|
1310
|
+
def _setup_template_filter(
|
1311
|
+
matching_data: MatchingData,
|
1312
|
+
rfftn: Callable,
|
1313
|
+
irfftn: Callable,
|
1314
|
+
fast_shape: Tuple[int],
|
1315
|
+
fast_ft_shape: Tuple[int],
|
1316
|
+
):
|
1317
|
+
filter_template = isinstance(matching_data.template_filter, Compose)
|
1318
|
+
filter_target = isinstance(matching_data.target_filter, Compose)
|
1319
|
+
|
1320
|
+
template_filter = backend.full(
|
1321
|
+
shape=(1,), fill_value=1, dtype=backend._default_dtype
|
1322
|
+
)
|
1323
|
+
|
1324
|
+
if not filter_template and not filter_target:
|
1325
|
+
return template_filter
|
1326
|
+
|
1327
|
+
target_temp = backend.astype(
|
1328
|
+
backend.topleft_pad(matching_data.target, fast_shape), backend._default_dtype
|
1329
|
+
)
|
1330
|
+
target_temp_ft = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
|
1331
|
+
rfftn(target_temp, target_temp_ft)
|
1332
|
+
|
1333
|
+
if isinstance(matching_data.template_filter, Compose):
|
1334
|
+
template_filter = matching_data.template_filter(
|
1335
|
+
shape=fast_shape,
|
1336
|
+
return_real_fourier=True,
|
1337
|
+
shape_is_real_fourier=False,
|
1338
|
+
data_rfft=target_temp_ft,
|
1339
|
+
)
|
1340
|
+
template_filter = template_filter["data"]
|
1341
|
+
template_filter[tuple(0 for _ in range(template_filter.ndim))] = 0
|
1342
|
+
|
1343
|
+
if isinstance(matching_data.target_filter, Compose):
|
1344
|
+
target_filter = matching_data.target_filter(
|
1345
|
+
shape=fast_shape,
|
1346
|
+
return_real_fourier=True,
|
1347
|
+
shape_is_real_fourier=False,
|
1348
|
+
data_rfft=target_temp_ft,
|
1349
|
+
weight_type=None,
|
1350
|
+
)
|
1351
|
+
target_filter = target_filter["data"]
|
1352
|
+
backend.multiply(target_temp_ft, target_filter, out=target_temp_ft)
|
1353
|
+
|
1354
|
+
irfftn(target_temp_ft, target_temp)
|
1355
|
+
matching_data._target = backend.topleft_pad(
|
1356
|
+
target_temp, matching_data.target.shape
|
1357
|
+
)
|
1358
|
+
|
1359
|
+
return template_filter
|
1360
|
+
|
1361
|
+
|
1243
1362
|
def device_memory_handler(func: Callable):
|
1244
1363
|
"""Decorator function providing SharedMemory Handler."""
|
1245
1364
|
|
@@ -1310,18 +1429,17 @@ def scan(
|
|
1310
1429
|
Tuple
|
1311
1430
|
The merged results from callback_class if provided otherwise None.
|
1312
1431
|
"""
|
1432
|
+
matching_data.to_backend()
|
1313
1433
|
shape_diff = backend.subtract(
|
1314
|
-
matching_data.
|
1434
|
+
matching_data._output_target_shape, matching_data._output_template_shape
|
1315
1435
|
)
|
1436
|
+
shape_diff = backend.multiply(shape_diff, ~matching_data._batch_mask)
|
1316
1437
|
if backend.sum(shape_diff < 0) and not pad_fourier:
|
1317
1438
|
warnings.warn(
|
1318
1439
|
"Target is larger than template and Fourier padding is turned off."
|
1319
1440
|
" This can lead to shifted results. You can swap template and target, or"
|
1320
1441
|
" zero-pad the target."
|
1321
1442
|
)
|
1322
|
-
|
1323
|
-
matching_data.to_backend()
|
1324
|
-
|
1325
1443
|
fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
|
1326
1444
|
pad_fourier=pad_fourier
|
1327
1445
|
)
|
@@ -1334,6 +1452,15 @@ def scan(
|
|
1334
1452
|
complex_dtype=matching_data._complex_dtype,
|
1335
1453
|
fftargs=fftargs,
|
1336
1454
|
)
|
1455
|
+
|
1456
|
+
# template_filter = _setup_template_filter(
|
1457
|
+
# matching_data=matching_data,
|
1458
|
+
# rfftn=rfftn,
|
1459
|
+
# irfftn=irfftn,
|
1460
|
+
# fast_shape=fast_shape,
|
1461
|
+
# fast_ft_shape=fast_ft_shape,
|
1462
|
+
# )
|
1463
|
+
|
1337
1464
|
setup = matching_setup(
|
1338
1465
|
rfftn=rfftn,
|
1339
1466
|
irfftn=irfftn,
|
@@ -1351,6 +1478,7 @@ def scan(
|
|
1351
1478
|
)
|
1352
1479
|
rfftn, irfftn = None, None
|
1353
1480
|
|
1481
|
+
|
1354
1482
|
template_filter, preprocessor = None, Preprocessor()
|
1355
1483
|
for method, parameters in matching_data.template_filter.items():
|
1356
1484
|
parameters["shape"] = fast_shape
|
@@ -1364,9 +1492,8 @@ def scan(
|
|
1364
1492
|
template_filter = backend.full(
|
1365
1493
|
shape=(1,), fill_value=1, dtype=backend._default_dtype
|
1366
1494
|
)
|
1367
|
-
else:
|
1368
|
-
template_filter = backend.to_backend_array(template_filter)
|
1369
1495
|
|
1496
|
+
template_filter = backend.to_backend_array(template_filter)
|
1370
1497
|
template_filter = backend.astype(template_filter, backend._default_dtype)
|
1371
1498
|
template_filter_buffer = backend.arr_to_sharedarr(
|
1372
1499
|
arr=template_filter,
|
@@ -1388,14 +1515,21 @@ def scan(
|
|
1388
1515
|
callback_class = setup.pop("callback_class", callback_class)
|
1389
1516
|
callback_class_args = setup.pop("callback_class_args", callback_class_args)
|
1390
1517
|
callback_classes = [callback_class for _ in range(n_callback_classes)]
|
1518
|
+
|
1519
|
+
convolution_mode = "same"
|
1520
|
+
if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
|
1521
|
+
convolution_mode = "valid"
|
1522
|
+
|
1523
|
+
|
1524
|
+
callback_class_args["fourier_shift"] = fourier_shift
|
1525
|
+
callback_class_args["convolution_mode"] = convolution_mode
|
1526
|
+
callback_class_args["targetshape"] = setup["targetshape"]
|
1527
|
+
callback_class_args["templateshape"] = setup["templateshape"]
|
1528
|
+
|
1391
1529
|
if callback_class == MaxScoreOverRotations:
|
1392
|
-
score_space_shape = backend.subtract(
|
1393
|
-
matching_data.target.shape,
|
1394
|
-
matching_data._target_pad,
|
1395
|
-
)
|
1396
1530
|
callback_classes = [
|
1397
1531
|
class_name(
|
1398
|
-
score_space_shape=
|
1532
|
+
score_space_shape=fast_shape,
|
1399
1533
|
score_space_dtype=matching_data._default_dtype,
|
1400
1534
|
shared_memory_handler=kwargs.get("shared_memory_handler", None),
|
1401
1535
|
rotation_space_dtype=backend._default_dtype_int,
|
@@ -1435,10 +1569,16 @@ def scan(
|
|
1435
1569
|
for index, rotation in enumerate(rotation_list)
|
1436
1570
|
)
|
1437
1571
|
|
1572
|
+
callbacks = callbacks[0:n_callback_classes]
|
1438
1573
|
callbacks = [
|
1439
|
-
tuple(callback
|
1440
|
-
|
1441
|
-
|
1574
|
+
tuple(callback._postprocess(
|
1575
|
+
fourier_shift = fourier_shift,
|
1576
|
+
convolution_mode = convolution_mode,
|
1577
|
+
targetshape = setup["targetshape"],
|
1578
|
+
templateshape = setup["templateshape"],
|
1579
|
+
shared_memory_handler=kwargs.get("shared_memory_handler", None)
|
1580
|
+
)) if hasattr(callback, "_postprocess") else tuple(callback)
|
1581
|
+
for callback in callbacks if callback is not None
|
1442
1582
|
]
|
1443
1583
|
backend.free_cache()
|
1444
1584
|
|
@@ -1549,11 +1689,13 @@ def scan_subsets(
|
|
1549
1689
|
matching_data._target, matching_data._template = None, None
|
1550
1690
|
matching_data._target_mask, matching_data._template_mask = None, None
|
1551
1691
|
|
1692
|
+
candidates = None
|
1552
1693
|
if callback_class is not None:
|
1553
1694
|
candidates = callback_class.merge(
|
1554
1695
|
results, **callback_class_args, inner_merge=False
|
1555
1696
|
)
|
1556
|
-
|
1697
|
+
|
1698
|
+
return candidates
|
1557
1699
|
|
1558
1700
|
|
1559
1701
|
MATCHING_EXHAUSTIVE_REGISTER = {
|
@@ -1563,6 +1705,7 @@ MATCHING_EXHAUSTIVE_REGISTER = {
|
|
1563
1705
|
"CAM": (cam_setup, corr_scoring),
|
1564
1706
|
"FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
|
1565
1707
|
"FLC": (flc_setup, flc_scoring),
|
1708
|
+
"FLC2": (flc_setup, flc_scoring2),
|
1566
1709
|
"MCC": (mcc_setup, mcc_scoring),
|
1567
1710
|
}
|
1568
1711
|
|